Codeforces1118F2. Tree Cutting (Hard Version)

题目

Solution

显而易见的是,两个同色点的 l c a lca 一定和这两个点划分在一棵树内,若 l c a lca 已经染色且颜色与这两点不同,那么答案很明显是 0 0
然后设计 d p dp 方程
f [ u ] [ 0 / 1 ] f[u][0/1] 表示 u u 的子树内(没有/有)有颜色的点的划分方案数
假设 u u 必须染成 c o l [ u ] col[u] (什么颜色都可以就是 0 0
u u 的子树内颜色与 u u 相同的点的个数为 n u m [ u ] num[u]
如果 u u 必须染某种颜色,那么 f [ u ] [ 1 ] = f a [ v ] = u f [ v ] [ v ] f[u][1]=\prod_{fa[v]=u} f[v][v必须染色]
如果 u u 可以染任何颜色,那么 f [ u ] [ 0 ] = f a [ v ] = u f [ v ] [ 0 ] f[u][0]=\prod_{fa[v]=u} f[v][0]
f [ u ] [ 1 ] = f a [ v ] = u f [ v ] [ 1 ] ( v f [ s o n [ u ] ] [ 0 ] ) f[u][1]=\sum_{fa[v]=u}f[v][1]*(除v外所有f[son[u]][0]的积)
如果 u u 可以染任何颜色或者它子树中与其颜色相同的个数与改颜色总数相同(这棵树可以分开来),那么 f [ u ] [ 0 ] + = f [ u ] [ 1 ] f[u][0]+=f[u][1]

Code

#include<bits/stdc++.h>
using namespace std;
const int N=300001,M=998244353;
#define FOR(v) for (int i=0,v;i<G[u].size();i++) if ((v=G[u][i])!=fa)
inline char gc(){
	static char buf[100000],*p1=buf,*p2=buf;
	return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline int rd(){
	int x=0,fl=1;char ch=gc();
	for (;ch<48||ch>57;ch=gc())if(ch=='-')fl=-1;
	for (;48<=ch&&ch<=57;ch=gc())x=(x<<3)+(x<<1)+(ch^48);
	return x*fl;
}
inline void wri(int a){if(a<0)a=-a,putchar('-');if(a>=10)wri(a/10);putchar(a%10|48);}
inline void wln(int a){wri(a);puts("");}
int n,x,y,c[N],col[N],num[N],f[N][2],a[N],i;
vector<int>G[N],fac,dao;
bool fl=1;
void dfs(int u,int fa){
	col[u]=a[u];
	num[u]=a[u]>0;
	FOR(v){
		dfs(v,u);
		if (!col[v]) continue;
		if (col[u] && col[u]!=col[v]) fl=0;
		col[u]=col[v];
		num[u]+=num[v];
	}
	if (col[u]){
		f[u][1]=1;
		FOR(v) f[u][1]=1ll*f[u][1]*f[v][col[v]>0]%M;
	}else{
		fac.clear(),dao.clear();
		f[u][0]=1,fac.push_back(1);
		FOR(v) f[u][0]=1ll*f[u][0]*f[v][0]%M,fac.push_back(f[v][0]),dao.push_back(f[v][0]);
		else fac.push_back(1),dao.push_back(1);
		dao.push_back(1);
		for (int i=1;i<fac.size();i++) fac[i]=1ll*fac[i-1]*fac[i]%M;
		for (int i=dao.size()-2;~i;i--) dao[i]=1ll*dao[i+1]*dao[i]%M;
		FOR(v) f[u][1]=(1ll*f[v][1]*fac[i]%M*dao[i+1]+f[u][1])%M;
	}
	if (!col[u] || c[col[u]]==num[u]){
		f[u][0]=(f[u][0]+f[u][1])%M;
		col[u]=num[u]=0;
	}
}
int main(){
	n=rd(),rd();
	for (i=0;i<n;i++) c[a[i]=rd()]++;
	for (i=1;i<n;i++) x=rd()-1,y=rd()-1,G[x].push_back(y),G[y].push_back(x);
	dfs(0,-1);
	printf("%d",f[0][1]*fl);
}

猜你喜欢

转载自blog.csdn.net/xumingyang0/article/details/87875457