Codeforces Round #446 (Div. 1)

D. Sloth

让我们暴力树形dp吧,题解的性质+分类讨论太难想了
推清楚所有转移细节!!
f[i][0/1][0/1][0/1/2/3] : 当前以i为根的子树,根是否匹配,子树中是否有未匹配点,删边加边的状态
删边加边的状态:
0:未删除
1: 删除且那块中有未匹配点
2: 删除且无未匹配点
3: 已经删除和加边
方案数要在删除和加入的时候*sz(如果可以任选点连的话)
转移分几类:直接匹配,在任何情况下都可以,不要漏掉了!,一个删除在另一块中加入,当前子树删除后直接加边或者留着以后加
细节很多,需要思维非常清晰
调了50min还是很不应该,提高效率!

//非常复杂的dp细节,要考虑全面所有情况
//复杂的状态宁愿手推转移,不要循环。主要是逻辑要清晰
//两个子树合并的时候除了删边和加边的拼接,还有直接匹配的转移方式,漏了好几次!
//把所有情况按类分开,转移可以边写代码边推完整,草稿纸上框出所有大类
//一定要边写转移边写注释,并且要仔细检查代码。会发现很多小错误!
#include<bits/stdc++.h>
using namespace std;
#define maxn 500020
#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define rvc(i,S) for(register int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(register int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x)for (register int i = head[x] ; i ; i = e[i].next)
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define inf 0x3f3f3f3f
typedef long long ll;
typedef pair<int,int> pr;

struct node{
	int next,to;
}e[maxn * 2];
int head[maxn],cnt,fa[maxn],n,sz[maxn];
ll f[maxn][2][2][4],g[2][2][4];

inline void adde(int x,int y){
	e[++cnt].to = y;
	e[cnt].next = head[x];
	head[x] = cnt;
}
void dfs(int x){
	sz[x] = 1;
	f[x][0][0][0] = 1;
	fore(i,x){
		if ( e[i].to == fa[x] ) continue;
		fa[e[i].to] = x;
		dfs(e[i].to);
		memcpy(g,f[x],sizeof(f[x]));
		memset(f[x],0,sizeof(f[x]));
		//x and e[i].to both arn't delete
		//e[i].to可以作为未匹配点,而不一定立即和根匹配
		//根匹配,无未匹配点
		f[x][1][0][0] += g[1][0][0] * f[e[i].to][1][0][0] + g[0][0][0] * f[e[i].to][0][0][0]; 
		//根匹配,有未匹配点
		f[x][1][1][0] += g[1][1][0] * f[e[i].to][1][0][0] + g[1][0][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]) + g[0][1][0] * f[e[i].to][0][0][0] + g[0][0][0] * f[e[i].to][0][1][0];
		//根未匹配,有无未匹配点
		f[x][0][1][0] += g[0][1][0] * f[e[i].to][1][0][0] + g[0][0][0] * (f[e[i].to][0][0][0] + f[e[i].to][1][1][0]);
		f[x][0][0][0] += g[0][0][0] * f[e[i].to][1][0][0];
		//正常匹配
			//x已经完成了删边和加边
			f[x][0][0][3] += g[0][0][3] * f[e[i].to][1][0][0];
			f[x][1][0][3] += g[0][0][3] * f[e[i].to][0][0][0] + g[1][0][3] * f[e[i].to][1][0][0];
			//e[i].to已经完成了删边和加边
			f[x][0][0][3] += g[0][0][0] * f[e[i].to][1][0][3];
			f[x][1][0][3] += g[0][0][0] * f[e[i].to][0][0][3] + g[1][0][0] * f[e[i].to][1][0][3];
			//x已经删边
			rep(t,1,2){
				f[x][0][0][t] += g[0][0][t] * f[e[i].to][1][0][0];
				f[x][1][0][t] += g[0][0][t] * f[e[i].to][0][0][0] + g[1][0][t] * f[e[i].to][1][0][0];
			}
			//e[i].to已经删边
			rep(t,1,2){
				f[x][0][0][t] += g[0][0][0] * f[e[i].to][1][0][t];
				f[x][1][0][t] += g[0][0][0] * f[e[i].to][0][0][t] + g[1][0][0] * f[e[i].to][1][0][t];
			}
		//完成拼接后,不能再有未匹配点
		//x已经删边,接到e[i].to中
		//删边的那块中有未匹配点
			//接到e[i].to上
			f[x][0][0][3] += g[0][0][1] * f[e[i].to][0][0][0];
			f[x][1][0][3] += g[1][0][1] * f[e[i].to][0][0][0];
			//接到e[i].to子树中的未匹配点
			f[x][0][0][3] += g[0][0][1] * f[e[i].to][1][1][0];
			f[x][1][0][3] += g[1][0][1] * f[e[i].to][1][1][0] + g[0][0][1] * f[e[i].to][0][1][0];
		//删边的那块无未匹配点
			//接到任意地方,e[i].to所有点必须匹配
			f[x][0][0][3] += g[0][0][2] * f[e[i].to][1][0][0] * sz[e[i].to];
			f[x][1][0][3] += (g[0][0][2] * f[e[i].to][0][0][0] + g[1][0][2] * f[e[i].to][1][0][0]) * sz[e[i].to];
		//e[i].to已经删边,接到x之前的子树中
		//有未匹配点	
			//接到x上
			f[x][1][0][3] += g[0][0][0] * f[e[i].to][1][0][1];
			//接到x子树内的未匹配点
			f[x][0][0][3] += g[0][1][0] * f[e[i].to][1][0][1];
			f[x][1][0][3] += g[1][1][0] * f[e[i].to][1][0][1] + g[0][1][0] * f[e[i].to][0][0][1];
		//无未匹配点
			//任意接到x内
			f[x][0][0][3] += g[0][0][0] * f[e[i].to][1][0][2] * sz[x];
			f[x][1][0][3] += (g[0][0][0] * f[e[i].to][0][0][2] + g[1][0][0] * f[e[i].to][1][0][2]) * sz[x];

		//删除当前子树
			//立即拼接
			f[x][0][0][3] += g[0][1][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]) + g[0][0][0] * f[e[i].to][1][0][0] * sz[e[i].to] * sz[x];
			f[x][1][0][3] += g[0][0][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]) + g[1][1][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]) + g[1][0][0] * f[e[i].to][1][0][0] * sz[e[i].to] * sz[x];
			//留着以后拼接	
			f[x][0][0][1] += g[0][0][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]);
			f[x][1][0][1] += g[1][0][0] * (f[e[i].to][1][1][0] + f[e[i].to][0][0][0]);
			f[x][0][0][2] += g[0][0][0] * f[e[i].to][1][0][0] * sz[e[i].to];
			f[x][1][0][2] += g[1][0][0] * f[e[i].to][1][0][0] * sz[e[i].to];
		
	

		sz[x] += sz[e[i].to];
	}
}
int main(){
	scanf("%d",&n);
	rep(i,1,n - 1){
		int x,y;
		scanf("%d %d",&x,&y);
		adde(x,y) , adde(y,x);
	}
	dfs(1);
	ll ans = f[1][1][0][3];
	printf("%lld\n",ans);
}

E - Lust

被这道题虐爆了!思路太巧了

首先观察出ans = ai最初的乘积 - 最后的乘积,用归纳法证明,一开始没有想到

然后写出一个暴力dp:f[S][k] = f[S][k - 1] - (1 / n) * f[S - 2 ^ i][k - 1]

表示第k轮只考虑S的数的期望乘积和。

这一步的目的是把k轮可以快速幂掉,再考虑dp怎么优化

观察dp转移,发现相当于在2 ^ n个点的图上,从起点走到2 ^ n - 1的方案,走 + 2 ^ i 的边时乘(-1/n),走自环不变。
于是就可以直接从最初状态算到末状态的贡献系数,并且只要集合中元素个数相同贡献系数相同,n^2 DP一下就好

题解上把矩阵的k次方看成线性映射,把出状态和末状态看成2^n维空间的向量,从线代的思想出发看待dp转移,可以学习这个思路

一开始想dp枚举每个数被减了几次,从方向上就错了!!

#include<bits/stdc++.h>
using namespace std;
#define maxn 5020
#define rep(i,l,r) for(register int i = l ; i <= r ; i++)
#define repd(i,r,l) for(register int i = r ; i >= l ; i--)
#define rvc(i,S) for(register int i = 0 ; i < (int)S.size() ; i++)
#define rvcd(i,S) for(register int i = ((int)S.size()) - 1 ; i >= 0 ; i--)
#define fore(i,x)for (register int i = head[x] ; i ; i = e[i].next)
#define pb push_back
#define prev prev_
#define stack stack_
#define mp make_pair
#define fi first
#define se second
#define inf 0x3f3f3f3f
typedef long long ll;
typedef pair<int,int> pr;

const ll mod = 1e9 + 7;
ll f[maxn][maxn],a[maxn],powk[maxn],ans,ans2,invn;
int n,k;

inline ll power(ll x,ll y){
	ll res = 1;
	while ( y ){
		if ( y & 1 ) res = res * x % mod;
		x = x * x % mod;
		y >>= 1;
	}
	return res;
}
void init(){
	powk[0] = 1;
	rep(i,1,n) powk[i] = powk[i - 1] * (k - i + 1) % mod;
	invn = mod - power(n,mod - 2);
}
int main(){
	scanf("%d %d",&n,&k);
	init();
	ans = 1;
	rep(i,1,n) scanf("%lld",&a[i]) , ans = ans * a[i] % mod;
	f[0][0] = 1;
	rep(i,1,n){
		f[i][0] = 1;
		rep(j,1,i){
			f[i][j] = (f[i - 1][j] + f[i - 1][j - 1] * a[i]) % mod;
		}
	}
	rep(i,0,n){
		ans2 = (ans2 + f[n][i] * power(invn,n - i) % mod * powk[n - i]) % mod;
	}
	cout<<(ans - ans2 +mod) % mod<<endl;
}

猜你喜欢

转载自blog.csdn.net/weixin_42484877/article/details/86352076