bzoj 4771: 七彩树

4771: 七彩树

Time Limit: 5 Sec  Memory Limit: 256 MB
Submit: 1334  Solved: 388
[Submit][Status][Discuss]

Description

给定一棵n个点的有根树,编号依次为1到n,其中1号点是根节点。每个节点都被染上了某一种颜色,其中第i个节
点的颜色为c[i]。如果c[i]=c[j],那么我们认为点i和点j拥有相同的颜色。定义depth[i]为i节点与根节点的距离
,为了方便起见,你可以认为树上相邻的两个点之间的距离为1。站在这棵色彩斑斓的树前面,你将面临m个问题。
每个问题包含两个整数x和d,表示询问x子树里且depth不超过depth[x]+d的所有点中出现了多少种本质不同的颜色
。请写一个程序,快速回答这些询问。

Input

第一行包含一个正整数T(1<=T<=500),表示测试数据的组数。
每组数据中,第一行包含两个正整数n(1<=n<=100000)和m(1<=m<=100000),表示节点数和询问数。
第二行包含n个正整数,其中第i个数为c[i](1<=c[i]<=n),分别表示每个节点的颜色。
第三行包含n-1个正整数,其中第i个数为f[i+1](1<=f[i]<i),表示节点i+1的父亲节点的编号。
接下来m行,每行两个整数x(1<=x<=n)和d(0<=d<n),依次表示每个询问。
输入数据经过了加密,对于每个询问,如果你读入了x和d,那么真实的x和d分别是x xor last和d xor last,
其中last表示这组数据中上一次询问的答案,如果这是当前数据的第一组询问,那么last=0。
输入数据保证n和m的总和不超过500000。

Output

对于每个询问输出一行一个整数,即答案。

Sample Input

1
5 8
1 3 3 2 2
1 1 3 3
1 0
0 0
3 0
1 3
2 1
2 0
6 2
4 1

Sample Output

1
2
3
1
1
2
1
1

HINT

 

Source

 
    非常经典的一道题,相当于是把序列上的区间颜色数 放到了树上,并且是加强版。
    考虑序列上的区间颜色数求法,(只考虑在线的话)我们是用可持久化线段树实现的每种颜色不会重复加;
    如果只是简单的放到树上,没有深度限制的话,直接通过dfs序转换成序列上的求就行了,也很简单。。。。
   
    所以这个深度限制是什么鬼啊QWQ
    这个时候只能换一种思路啦。
    考虑把节点按照深度从小到大一层一层加进去,我们加入一个节点的时候要保证,此时以任意一个节点为根的子树中,这种颜色只会被算一次。
    这该怎么做呢????
 
    当这个节点x是这种颜色第一个加入的节点,那么显然只需要在这个节点的dfs序上+1就行了;
    如果有其他的节点,因为我们已经保证了不考虑这个节点的时候的任意子树,每种颜色只会被算一次,所以我们只需要考虑  子树中包含这个节点x 并且 原来已经有这种颜色的 点。设pre是该颜色的点中dfs序小于x的dfs序最大的节点,nxt是。。。。大于x的。。。最小的点,那么LCA(x,nxt)--,LCA(x,pre),LCA(pre,nxt)++就行啦(请自行画图草稿)
 
#include<bits/stdc++.h>
#define ll long long
using namespace std;
#define pb push_back
#define mid (l+r>>1)
const int maxn=100005;
set<int> s[maxn];
vector<int> g[maxn];
set<int> :: iterator it;
int dfn[maxn],dy[maxn],siz[maxn],son[maxn],L;
int n,m,to[maxn],ne[maxn],hd[maxn],num,col[maxn];
int F[maxn],T,cl[maxn],dc,le,ri,w,M,dep[maxn],ans;
struct node{
	int S;
	node *lc,*rc;
}nil[maxn*123],*rot[maxn],*cnt;

inline int read(){
	int x=0; char ch=getchar();
	for(;ch<'0'||ch>'9';ch=getchar());
	for(;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
	return x;
}
void W(int x){ if(x>=10) W(x/10); putchar(x%10+'0');}

inline void init(){
	fill(hd+1,hd+n+1,0),num=0;
	for(int i=1;i<=M;i++) g[i].clear(); 
	fill(son+1,son+n+1,0),M=0;dc=0;
	cnt=nil->lc=nil->rc=rot[0]=nil;
	for(int i=1;i<=n;i++) s[i].clear();
}

inline void add(int x,int y){ to[++num]=y,ne[num]=hd[x],hd[x]=num;}

void Fdfs(int x){
	siz[x]=1,M=max(M,dep[x]),g[dep[x]].pb(x);
	for(int i=hd[x];i;i=ne[i]){
		dep[to[i]]=dep[x]+1,Fdfs(to[i]);
		siz[x]+=siz[to[i]];
		if(!son[x]||siz[to[i]]>siz[son[x]]) son[x]=to[i];
	}
}

void Sdfs(int x,int tp){
	cl[x]=tp,dfn[x]=++dc,dy[dc]=x;
	if(!son[x]) return;
	
	Sdfs(son[x],tp);
	
	for(int i=hd[x];i;i=ne[i]) if(to[i]!=son[x]) Sdfs(to[i],to[i]); 
}

int LCA(int x,int y){
	while(cl[x]!=cl[y]){
		if(dep[cl[x]]>dep[cl[y]]) x=F[cl[x]];
		else y=F[cl[y]];
	}
	
	return dep[x]>dep[y]?y:x;
}

node *update(node *u,int l,int r){
	node *ret=++cnt;
	*ret=*u,ret->S+=w;
	if(l==r) return ret;
	
	if(le<=mid) ret->lc=update(ret->lc,l,mid);
	else ret->rc=update(ret->rc,mid+1,r);
	
	return ret;
}

void query(node *u,int l,int r){
	if(l>=le&&r<=ri){ ans+=u->S; return;}
	if(le<=mid) query(u->lc,l,mid);
	if(ri>mid) query(u->rc,mid+1,r);
}

inline void solve(){
	dep[1]=1,Fdfs(1),Sdfs(1,1);
	
	for(int i=1;i<=M;i++){
		rot[i]=rot[i-1];
		
	    for(int j=g[i].size()-1,now;j>=0;j--){
	    	now=g[i][j];
	    	le=dfn[now],w=1,rot[i]=update(rot[i],1,n);
	    	
	    	if(s[col[now]].size()){
	    		it=s[col[now]].upper_bound(dfn[now]);
	    		if(it==s[col[now]].end()) le=dfn[LCA(now,dy[*(--it)])],w=-1,rot[i]=update(rot[i],1,n);
	    		else if(it==s[col[now]].begin()) le=dfn[LCA(now,dy[*it])],w=-1,rot[i]=update(rot[i],1,n);
	    		else{
	    			int nxt=dy[*it],pre=dy[*(--it)];
	    			le=dfn[LCA(now,nxt)],w=-1,rot[i]=update(rot[i],1,n);
	    			le=dfn[LCA(now,pre)],w=-1,rot[i]=update(rot[i],1,n);
	    			le=dfn[LCA(pre,nxt)],w=1,rot[i]=update(rot[i],1,n);
				}
			}
			
			s[col[now]].insert(dfn[now]);
		}
	}
	
	while(m--){
		le=read()^ans,L=read()^ans,ans=0;
		L=min(M,dep[le]+L),ri=dfn[le]+siz[le]-1,le=dfn[le];
		query(rot[L],1,n),W(ans),puts("");
	}
}

int main(){
//	freopen("data.in","r",stdin);
//	freopen("data.out","w",stdout);
	
	T=read();
	
	while(T--){
		init(),n=read(),m=read();
		for(int i=1;i<=n;i++) col[i]=read();
		for(int i=2;i<=n;i++) F[i]=read(),add(F[i],i);
		solve();
	}
	
	return 0;
}

  

猜你喜欢

转载自www.cnblogs.com/JYYHH/p/9062645.html