题意
给一颗树,树上每个节点都有一个开关,每次操作对一个节点及其子树上节点的开关翻转,询问以某个节点为根的子树上开着的开关数目。
题解
dfs序将树形结构转换成线性结构,然后用线段树+lazy标记维护即可。
需要注意的就是编号问题,因为原树中的编号和线段树维护的区间编号不是一一对应的。
dfs序的pre编号是原树编号到线段树编号的一个hash。
所以在建树初始化的时候有一下两种方法:
- 单点更新,对线段树中的pre[i]更新。
- 建树时直接初始化,线段树叶子节点对应的树上编号为hashback[i]。
不要搞混。
代码
#include<bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef unsigned long long ull;
const int nmax = 1e6+7;
const int INF = 0x3f3f3f3f;
const ll LINF = 0x3f3f3f3f3f3f3f3f;
const ull p = 67;
const ull MOD = 1610612741;
int pre[nmax],suf[nmax],head[nmax],sz[nmax],hashback[nmax],dfsnum,tot;
int status[nmax];
struct node{int to,nxt;}e[nmax<<1];
struct treenode{
int l,r,val;bool flip;
int mid() {return (l+r)>>1;}
}tree[nmax<<2];
vector<int> v[nmax];
void add_edge(int u, int v){ e[tot].to = v, e[tot].nxt = head[u], head[u] = tot++;}
void dfs(int u, int f){
pre[u] = ++dfsnum; sz[u] = 1,hashback[dfsnum] = u;
for(int i = 0;i<v[u].size();++i) dfs(v[u][i],u),sz[u] += sz[v[u][i]];
suf[u] = dfsnum;
}
void pushup(int rt){
tree[rt].val = tree[rt<<1].val + tree[rt<<1|1].val;
}
void pushdown(int rt){
if(tree[rt].flip){
tree[rt<<1].val = (tree[rt<<1].r - tree[rt<<1].l + 1) - tree[rt<<1].val;
tree[rt<<1|1].val = (tree[rt<<1|1].r - tree[rt<<1|1].l + 1) - tree[rt<<1|1].val;
tree[rt<<1].flip ^= 1;
tree[rt<<1|1].flip ^= 1;
tree[rt].flip = 0;
}
}
void build(int l, int r, int rt){
tree[rt].l = l, tree[rt].r = r;
tree[rt].flip = tree[rt].val = 0;
if(tree[rt].l == tree[rt].r){
tree[rt].val = status[hashback[tree[rt].l]];
return;
};
build(l,tree[rt].mid(),rt<<1);
build(tree[rt].mid()+1,r,rt<<1|1);
pushup(rt);
}
void update(int l, int r, int rt){
if(l <= tree[rt].l && tree[rt].r <= r){
tree[rt].flip ^= 1;
tree[rt].val = (tree[rt].r - tree[rt].l + 1) - tree[rt].val;
return;
}
pushdown(rt);
if(r <= tree[rt].mid()) update(l,r,rt<<1);
else if(l>tree[rt].mid()) update(l,r,rt<<1|1);
else update(l,tree[rt].mid(),rt<<1),update(tree[rt].mid()+1,r,rt<<1|1);
pushup(rt);
}
int query(int l, int r, int rt){
if(tree[rt].l >= l && tree[rt].r <= r ) return tree[rt].val;
pushdown(rt);
if(r <= tree[rt].mid()) return query(l,r,rt<<1);
else if(l > tree[rt].mid()) return query(l,r,rt<<1|1);
else return query(l,tree[rt].mid(),rt<<1) + query(tree[rt].mid()+1,r,rt<<1|1) ;
}
int n,m;
int main() {
scanf("%d",&n);
int fa = 0;
for(int u = 2;u<=n;++u) scanf("%d",&fa), v[fa].push_back(u);
for(int i = 1;i<=n;++i) scanf("%d",&status[i]);
dfs(1,-1);
build(1,n,1);
scanf("%d",&m); char op[20]; int pos;
for(int i = 1;i<=m;++i){
scanf("%s %d",op,&pos);
if(op[0] == 'g') printf("%d\n",query(pre[pos],suf[pos],1));
else update(pre[pos],suf[pos],1);
}
return 0;
}