树上莫队小结

树上莫队首先是考虑将树上的问题转移成序列上的问题
考虑括号序,一个点在dfs进入它和退出去时分别记一次。这样,考虑原树中的一条路径(u,v)可以被在序列中表示成什么样子:设u的括号序更小,那么如果u是LCA(u,v),那么路径可以被表示成(st[u],st[v]),否则路径可以被表示成(ed[u],st[v]),无论怎样表示,我们都只统计出现在序列中出现奇数次的元素的信息。注意这样的话第2种情况我们并没有统计LCA的信息,所以需要单独判断。然后因为括号序的长度是原树大小的2倍,所以数组记得开大。

例题:
bzoj3757

比较裸的一道题目(笔者一开始妄图树上差分,但是没法去重)

#include<bits/stdc++.h>
using namespace std;
const int maxn=1e5+5;
inline int read(){
	char c=getchar();int t=0,f=1;
	while((!isdigit(c))&&(c!=EOF)){if(c=='-')f=-1;c=getchar();}
	while((isdigit(c))&&(c!=EOF)){t=(t<<3)+(t<<1)+(c^48);c=getchar();}
	return t*f;
}
int n,m,a[maxn];
struct edge{
	int v,p;
}e[maxn<<1];
int h[maxn],cnt;
inline void add(int a,int b){
	e[++cnt].p=h[a];
	e[cnt].v=b;
	h[a]=cnt;
	e[++cnt].p=h[b];
	e[cnt].v=a;
	h[b]=cnt;
}
int st[maxn],ed[maxn],pos[maxn],tot,dep[maxn];
int f[maxn][20];
void dfs1(int u,int fa){
	tot++;
	st[u]=tot;pos[tot]=u;dep[u]=dep[fa]+1;
	f[u][0]=fa;
	for(int i=1;f[u][i-1];i++)f[u][i]=f[f[u][i-1]][i-1];
	for(int i=h[u];i;i=e[i].p){
		int v=e[i].v;
		if(v==fa)continue;
		dfs1(v,u);
	}
	ed[u]=++tot;pos[tot]=u;
}
struct node{
	int l,r,pos,a,b,lca,id;
}q[maxn];
int lim;
int out[maxn];
bool cmp(node a,node b){
	if(a.pos==b.pos){
		if(a.pos&1)return a.r<b.r;
		return a.r>b.r;
	}
	return a.pos<b.pos;
}
int col[maxn],vis[maxn],ans;
void add(int x){
	vis[pos[x]]++;
	if(vis[pos[x]]&1){
		col[a[pos[x]]]++;
		if(col[a[pos[x]]]==1&&pos[x]){
			//printf("%d %d !\n",x,pos[x]);
			ans++;
		}
	}
	else{
		col[a[pos[x]]]--;
		if(col[a[pos[x]]]==0&&pos[x]){
			//printf("%d %d ?\n",x,pos[x]);
			ans--;
		}
	}
}
void remove(int x){
	vis[pos[x]]--;
	if(vis[pos[x]]&1){
		col[a[pos[x]]]++;
		if(col[a[pos[x]]]==1&&pos[x]){
			//printf("%d %d !\n",x,pos[x]);
			ans++;
		}
	}
	else{
		col[a[pos[x]]]--;
		if(col[a[pos[x]]]==0&&pos[x]){
			//printf("%d %d ?\n",x,pos[x]);
			ans--;
		}
	}
}
inline int lca(int a,int b){
	if(dep[a]>dep[b])swap(a,b);
	for(int i=19;i>=0;i--){
		if(dep[a]<=dep[f[b][i]])b=f[b][i];
	}
	if(a==b)return a;
	for(int i=19;i>=0;i--){
		if(f[a][i]!=f[b][i]){a=f[a][i];b=f[b][i];}
	}
	return f[a][0];
}
signed main(){
	n=read(),m=read();
	for(int i=1;i<=n;i++)a[i]=read();
	for(int i=1;i<=n;i++){
		int a=read(),b=read();
		add(a,b);
	}
	dfs1(0,0);
	lim=sqrt(2*m);
	for(int i=1;i<=m;i++){
		int x=read(),y=read();
		if(st[x]>st[y])swap(x,y);
		q[i].lca=lca(x,y);
		//printf("%d %d %d %d %d\n",x,q[i].lca,y,st[x],st[y]);
		if(q[i].lca==x)
		q[i].l=st[x];
		else
		q[i].l=ed[x];
		q[i].r=st[y],q[i].a=read(),q[i].b=read();
		q[i].pos=q[i].l/lim+1;q[i].id=i;
	}
	sort(q+1,q+1+m,cmp);
	int l=1,r=1;
	/*for(int i=1;i<=tot;i++)
	printf("%d ",pos[i]);
	puts("");*/
	for(int i=1;i<=m;i++){
		while(r<q[i].r)add(++r);
		while(r>q[i].r)remove(r--);
		while(l<q[i].l)remove(l++);
		while(l>q[i].l)add(--l);
		out[q[i].id]=ans;
		if(!col[a[q[i].lca]]){out[q[i].id]++;}
		if(q[i].a!=q[i].b){
			col[a[q[i].lca]]++;
			if(col[q[i].a]&&col[q[i].b])out[q[i].id]--;
			col[a[q[i].lca]]--;
		}
		/*for(int i=1;i<=n;i++)
		printf("%d ",vis[i]);
		puts("");
		printf("%d %d %d %d %d\n",ans,q[i].a,q[i].b,q[i].l,q[i].r);*/
	}
	for(int i=1;i<=m;i++)printf("%d\n",out[i]);
	return 0;
}

发布了95 篇原创文章 · 获赞 9 · 访问量 3179

猜你喜欢

转载自blog.csdn.net/wmhtxdy/article/details/104680855