4.3 省选模拟赛 采蘑菇 点分治

给出一棵树 每个点都有一个颜色ci 问 从i号点出发到任意一个点的路径上本质不同的颜色之和。

\(n\leq 300000\)

光线性扫描时不行的 显然有\(n^2\)的暴力。

考虑树是一条链的时候怎么做? 可以发现先求出1的答案然后维护换根的过程 记录每个点颜色的pre 前驱 nex后继很容易通过分类讨论得到答案。

考虑树的时候怎么做?还是维护换根的过程 当两个点颜色相同的时候 答案显然一样,当不同的时候 可以分析得出 要查自己子树内没有自己父亲颜色的祖先的节点个数。还要查 自己子树之外没有自己的颜色所庇护的节点个数。

第一个很好查 预处理一下祖先庇护了哪写节点 经过祖先的话就下方这些节点 通过线段树维护dfs序很容易区间求和。

考虑第二个由于在不断换根的过程中 有很多节点可以庇护自己子树之外的节点 这个是存在轮换关系的 但是考虑和上一步的下方关系并不容易合并 或者说合并非常的困难所以这个做法是不成立的。

考试的时候我只是关心了 一下换根第一个步应该怎么做 却没有长远的眼光看到换到若干步之后思路的错误。

果然 换根主要解决的问题 是自己父亲那边的处理问题 这点尤其重要。

只能换个思路了 。事实上对于树上路径信息统计问题 或者说对于第一步暴力的优化 有一个非常有效的做法 点分治。

我们可以尝试利用点分治来维护刚才的暴力的过程 从而求解答案。

具体过程:考虑颜色数较少的情况 对于每一种颜色单独点分治一次。

可以发现对于非当前分治重心的点来说 对于当前颜色 dfs统计一下其到分治重心的路径上出现这种颜色了没有 如果出现了 显然这种颜色的贡献为分治大小-这个点在当前分治重心下的子树大小 考虑没有出现 显然是其他子树内点到到分治重心出现这种颜色的点的个数-自己子树中的这样的点的个数。

考虑分治重心 显然是所有子树内到分治重心出现这种颜色的点的个数。

可以发现这个做法 可以拓展到多种颜色上 在点分治的时候预处理一下第二种贡献再进行计算即可。

换根的代码(虽然是错误的 但是只是换根的时候父亲往上的那部分点没有办法在O(logn)的时间内处理罢了。

const int MAXN=300010;
int n,len,cnt,top,id;ll ans1[MAXN],ans;
int a[MAXN],root[MAXN],dfn[MAXN],sz[MAXN],c[MAXN];
int lin[MAXN],ver[MAXN<<1],nex[MAXN<<1];
struct wy{int l,r;int sum;}t[MAXN*30];
vector<int>g[MAXN];
inline void add(int x,int y)
{
	ver[++len]=y;
	nex[len]=lin[x];
	lin[x]=len;
}
inline void dfs(int x,int fa)
{
	sz[x]=1;dfn[x]=++cnt;
	int ww=c[a[x]];
	c[a[x]]=x;g[ww].pb(x);
	if(!ww)++ans;ans1[1]+=ans;
	go(x)
	{
		if(tn==fa)continue;
		dfs(tn,x);
		sz[x]+=sz[tn];
	}
	c[a[x]]=ww;if(!ww)--ans;
}
inline void change(int &p,int l,int r,int x,int w)
{
	if(!p)p=++id;
	if(l==r){sum(p)+=w;return;}
	int mid=(l+r)>>1;
	if(x<=mid)change(l(p),l,mid,x,w);
	else change(r(p),mid+1,r,x,w);
	sum(p)=sum(l(p))+sum(r(p));
}
inline int ask(int p,int l,int r,int L,int R)
{
	if(L>R)return 0;
	if(!p)return 0;
	if(L<=l&&R>=r)return sum(p);
	int mid=(l+r)>>1;
	if(L>mid)return ask(r(p),mid+1,r,L,R);
	if(R<=mid)return ask(l(p),l,mid,L,R);
	return ask(l(p),l,mid,L,R)+ask(r(p),mid+1,r,L,R);
}
inline void dp(int x,int fa)
{
	change(root[a[x]],1,n,dfn[x],-sz[x]);
	for(ui i=0;i<g[x].size();++i)
	{
		int tn=g[x][i];
		change(root[a[tn]],1,n,dfn[tn],sz[tn]);
	}
	go(x)
	{
		if(tn==fa)continue;
		if(a[tn]==a[x])ans1[tn]=ans1[x];
		else
		{
			ans1[tn]=ans1[x]-(sz[tn]-ask(root[a[x]],1,n,dfn[tn],dfn[tn]+sz[tn]-1));
			ans1[tn]=ans1[tn]+(n-sz[tn]-ask(root[a[tn]],1,n,1,dfn[tn]-1)-ask(root[a[tn]],1,n,dfn[tn]+sz[tn],n));
			//if(tn==3)put(ask(root[a[tn]],1,n,dfn[tn]+sz[tn],n));
		}
		dp(tn,x);
	}
	change(root[a[x]],1,n,dfn[x],sz[x]);
	for(ui i=0;i<g[x].size();++i)
	{
		int tn=g[x][i];
		change(root[a[tn]],1,n,dfn[tn],-sz[tn]);
	}
}
int main()
{
	freopen("1.in","r",stdin);
	freopen("2.out","w",stdout);
	get(n);
	rep(1,n,i)get(a[i]);
	rep(1,n-1,i)
	{
		int x,y;
		get(x);get(y);
		add(x,y);add(y,x);
	}
	dfs(1,0);//putl(ans1[1]);
	for(ui i=0;i<g[0].size();++i)
	{
		change(root[a[g[0][i]]],1,n,dfn[g[0][i]],sz[g[0][i]]);
		//put(g[0][i]);
	}
	dp(1,0);rep(1,n,i)putl(ans1[i]);return 0;
}

点分治的代码

const int MAXN=300010;
int n,len,cnt,id,top,maxx,rt,tg;
int a[MAXN],sz[MAXN],son[MAXN],vis[MAXN],c[MAXN];
int lin[MAXN],ver[MAXN<<1],nex[MAXN<<1];
ll ans1[MAXN],s[MAXN],ans,sum;
inline void add(int x,int y)
{
	ver[++len]=y;
	nex[len]=lin[x];
	lin[x]=len;
}
inline void get_root(int x,int fa)
{
	sz[x]=1;son[x]=0;
	go(x)
	{
		if(tn==fa||vis[tn])continue;
		get_root(tn,x);
		sz[x]+=sz[tn];
		son[x]=max(son[x],sz[tn]);
	}
	son[x]=max(son[x],maxx-sz[x]);
	if(son[x]<son[rt])rt=x;
}
inline void dfs(int x,int fa,int v)
{
	int ww=c[a[x]];
	if(!c[a[x]]){c[a[x]]=x;s[a[x]]+=sz[x]*v;ans+=sz[x]*v;}
	go(x)if(tn!=fa&&!vis[tn])dfs(tn,x,v);
	if(!ww)c[a[x]]=ww;
}
inline void dp(int x,int fa)
{
	int ww=c[a[x]];
	if(!c[a[x]]){c[a[x]]=x;sum+=cnt;ans-=s[a[x]];}
	ans1[x]+=sum+cnt;ans1[x]+=ans;
	go(x)if(tn!=fa&&!vis[tn])dp(tn,x);
	if(!ww){c[a[x]]=ww;sum-=cnt;ans+=s[a[x]];}
}
inline void solve(int x)
{
	get_root(x,0);c[a[x]]=x;vis[x]=1;
	go(x)if(!vis[tn])dfs(tn,x,1);
	ans1[x]+=sz[x]+ans;cnt=sz[x];
	go(x)
	{
		if(vis[tn])continue;
		dfs(tn,x,-1);cnt-=sz[tn];
		dp(tn,x);cnt+=sz[tn];
		dfs(tn,x,1);
	}
	go(x)if(!vis[tn])dfs(tn,x,-1);
	c[a[x]]=0;
	go(x)
	{
		if(vis[tn])continue;
		rt=0;maxx=sz[tn];
		get_root(tn,0);
		solve(rt);
	}
}
int main()
{
	freopen("1.in","r",stdin);
	freopen("2.out","w",stdout);
	get(n);
	rep(1,n,i)get(a[i]);
	rep(1,n-1,i)
	{
		int x,y;
		get(x);get(y);
		add(x,y);add(y,x);
	}
	rt=0;son[0]=n+1;maxx=n;get_root(1,0);
	solve(rt);rep(1,n,i)putl(ans1[i]);
	return 0;
}

其实核心就是分两步讨论。

猜你喜欢

转载自www.cnblogs.com/chdy/p/12627816.html
4.3