第一次没有用模板写树链剖分,感觉爽爽的。‘
题意:根节点为1的树上有三种操作:
1.将u和它的所有孩子节点置1 ;2.将u和它的所有祖先置0 ;3.查询u节点是0还是1
思路:一操作相当于将u和它的子节点这个区间所有的值进行更新,在代码中显示为[ id[u],R[u] ]。二操作相当于更新1到u的这条链上的节点的值,不过不一定在一个连续的区间内,所以需要树链剖分,分成很多连续的子区间,进行剖分。3,直接查询
简单说明一下我理解的剖分的过程:我们在dfs2的编号时候是按照先重边上的点先编号,即重边上的那个点的编号是父节点的编号+1,所以重链上的点的编号都是连续的。剖分的时候就是将区间分成若干个重链,然后进行维护,因为他们编号连续。
#include<bits/stdc++.h> using namespace std; const int maxn = 5e5 + 10; typedef long long ll; #define clr(x,y) memset(x,y,sizeof x) #define INF 0x3f3f3f3f int len,head[maxn]; struct Edge{int to,next;}edge[maxn << 1]; int edge_num; int num,dep[maxn],son[maxn],siz[maxn],fa[maxn],top[maxn],id[maxn]; int R[maxn]; void add_edge(int x,int y) { edge[edge_num] = (Edge){y,head[x]};head[x] = edge_num ++; } void Init() { edge_num = 0;clr(head,-1);num = 0; } void dfs1(int u,int pre,int d) { cout << u << endl; son[u] = 0;dep[u] = d;siz[u] = 1;fa[u] = pre; for(int i = head[u];i != -1;i = edge[i].next) { int v = edge[i].to; if(v == pre)continue; dfs1(v,u,d +1); siz[u] += siz[v]; if(siz[v] > siz[son[u]]) son[u] = v; } } void dfs2(int u,int tp) { id[u] = ++ num;top[u] = tp; if(son[u] != 0) dfs2(son[u],tp); for(int i = head[u];i != -1;i = edge[i].next) { int v = edge[i].to; if(v == fa[u] || v == son[u])continue; dfs2(v,v); } R[u] = num; } int tree[maxn << 2]; void build(int l,int r,int rt) { tree[rt] = 0; if(l == r) { return ; } int mid = (l + r) >>1; build(l,mid,rt << 1);build(mid + 1,r,rt <<1|1); } void update(int L,int R,int x,int l,int r,int rt) { if(tree[rt] == x)return; if(L <= l && R >= r) { tree[rt] = x;return; } if(tree[rt] != -1) { tree[rt << 1] = tree[rt << 1|1] = tree[rt]; tree[rt] = -1; } int mid = (l + r) >> 1; if(L <= mid) update(L,R,x,l,mid,rt << 1); if(R >= mid + 1) update(L,R,x,mid + 1,r,rt << 1|1); } int query(int pos,int l,int r,int rt) { if(l == r) { return tree[rt]; } if(tree[rt] != -1) { tree[rt << 1] = tree[rt << 1|1] = tree[rt]; tree[rt] = -1; } int mid = (l + r) >> 1; if(pos <= mid) return query(pos,l,mid,rt << 1); else return query(pos,mid + 1,r,rt << 1|1); } void updat(int u,int v,int x) { int f1 = top[u],f2 = top[v]; while(f1 != f2) { if(dep[f1] < dep[f2]) { swap(f1,f2);swap(u,v); } update(id[f1],id[u],x,1,num,1); u = fa[f1];f1 = top[u]; } if(dep[u] > dep[v])swap(u,v); update(id[u],id[v],x,1,num,1); } int main() { int n; while( ~ scanf("%d",&n)) { Init(); for(int i = 1;i <= n - 1;i ++) { int x,y;scanf("%d%d",&x,&y); add_edge(x,y);add_edge(y,x); } dfs1(1,1,1); dfs2(1,1); build(1,num,1); int m;scanf("%d",&m); while(m --) { int x,y;scanf("%d%d",&x,&y); if(x == 1) { update(id[y],R[y],1,1,num,1); } else if(x == 2) { updat(1,y,0); } else { printf("%d\n",query(id[y],1,num,1)); } } } return 0; }