Codeforces 997D Cycles in Product (点分治、DP计数)

题目链接

https://codeforces.com/contest/997/problem/D

题解

点分治这个思路想不到==

首先这两棵树的笛卡尔积并没有什么用处,因为笛卡尔积中的环就是两棵树中各找一个环按任意顺序归并起来(且不难证明不同的归并顺序对应不同的方案)。只需要对两棵树分别求出 \(ans_i\) 表示有多少个长度为 \(i\) 的环。

注意由于 1213 这种环的存在,我们不能直接钦定某个具有特殊性质的点为起点而忽略起点然后乘以环长再去掉周期更小的。一个 naive 的想法是设 \(dp[u_1][u_2][k]\) 表示起点为 \(u_1\) 当前在 \(u_2\) 长度为 \(k\) 的方案数。考虑点分治,设分治中心为 \(cent\),计算经过 \(cent\) 的环的个数。这个很容易(很套路),设 \(f[u][k]\) 表示 \(u\) 点走了 \(k\) 步到 \(cent\) 且之前没到过 \(cent\) 的方案数,\(g[u][k]\) 表示 \(cent\) 走了 \(k\) 步到 \(u\) 的方案数,转移显然,求答案枚举第一次到 \(cent\) 的时间即可。

时间复杂度 \(O(k^2n\log n)\).

代码

#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define iter iterator
#define riter reversed_iterator
#define y1 Lorem_ipsum_dolor
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int mxN = 4000;
const int mxM = 75;
const int P = 998244353;

llong comb[mxN+3][mxN+3];
int m;

void updsum(llong &x,llong y) {x = x+y>=P?x+y-P:x+y;}

void initcomb(int n)
{
	comb[0][0] = 1ll;
	for(int i=1; i<=n; i++) {comb[i][0] = comb[i][i] = 1ll; for(int j=1; j<i; j++) updsum(comb[i][j]=comb[i-1][j-1],comb[i-1][j]);}
}

struct Tree
{
	struct Edge
	{
		int v,nxt;
	} e[(mxN<<1)+3];
	int fe[mxN+3];
	int fa[mxN+3];
	int sz[mxN+3],mxsz[mxN+3]; bool vis[mxN+3];
	llong f[mxM+3][mxN+3],g[mxM+3][mxN+3];
	llong ans[mxN+3];
	vector<int> now;
	int n,en,cent,siz;
	void addedge(int u,int v)
	{
		en++; e[en].v = v;
		e[en].nxt = fe[u]; fe[u] = en;
	}
	void getCentroid(int u,int prv)
	{
		now.push_back(u);
		sz[u] = 1,mxsz[u] = 0;
		for(int i=fe[u]; i; i=e[i].nxt)
		{
			int v = e[i].v; if(v==prv||vis[v]) continue;
			getCentroid(v,u);
			sz[u] += sz[v];
			mxsz[u] = max(mxsz[u],sz[v]);
		}
		mxsz[u] = max(mxsz[u],siz-sz[u]);
		if(cent==0||mxsz[u]<mxsz[cent]) {cent = u;}
	}
	void findCentroid(int u)
	{
		now.clear(); cent = 0; getCentroid(u,0);
	}
	void dfs1()
	{
//		printf("dfs1 %d\n",cent);
//		printf("now: "); for(int i=0; i<now.size(); i++) printf("%d ",now[i]); puts("");
		int tsiz = siz;
		f[0][cent] = 1ll,g[0][cent] = 1ll;
		for(int i=0; i<m; i++)
		{
			for(int j=0; j<now.size(); j++)
			{
				int u = now[j];
				for(int o=fe[u]; o; o=e[o].nxt)
				{
					int v = e[o].v; if(vis[v]) continue;
					if(v!=cent) {updsum(f[i+1][v],f[i][u]);}
					updsum(g[i+1][v],g[i][u]);
				}
			}
		}
		for(int i=0; i<now.size(); i++)
		{
			int u = now[i];
			for(int j=0; j<=m; j++)
			{
				for(int k=0; k<=j; k++)
				{
					updsum(ans[j],f[k][u]*g[j-k][u]%P);
				}
			}
		}
		for(int i=0; i<now.size(); i++) for(int j=0; j<=m; j++) {f[j][now[i]] = g[j][now[i]] = 0ll;}
		vis[cent] = true;
		for(int i=fe[cent]; i; i=e[i].nxt)
		{
			int v = e[i].v; if(vis[v]) continue;
			siz = sz[v]>sz[cent]?tsiz-sz[cent]:sz[v]; findCentroid(v);
			dfs1();
		}
	}
	void solve()
	{
		for(int i=1; i<n; i++) {int u = read(),v = read(); addedge(u,v),addedge(v,u);}
		siz = n; findCentroid(1); dfs1();
//		printf("ans: "); for(int i=0; i<=m; i++) printf("%I64d ",ans[i]); puts("");
	}
} T1,T2;

int main()
{
	T1.n = read(); T2.n = read(); m = read();
	initcomb(m);
	T1.solve();
	T2.solve();
	llong ans = 0ll;
	for(int i=0; i<=m; i++)
	{
		updsum(ans,comb[m][i]*T1.ans[i]%P*T2.ans[m-i]%P);
	}
	printf("%I64d\n",ans);
	return 0;
}

猜你喜欢

转载自www.cnblogs.com/suncongbo/p/12657120.html