【Nowcoder】牛客练习赛1 B-树 | 树形dp、组合数学

题目大意:

shy有一颗树,树有n个结点。有k种不同颜色的染料给树染色。一个染色方案是合法的,当且仅当对于所有相同颜色的点对(x,y),x到y的路径上的所有点的颜色都要与x和y相同。请统计方案数。

题目思路:

把题目要求转换一下,对于每一种颜色均为一个连通块。

否则,则不满足要求

所以可以考虑,把这个树分成 k k k个连通块的方案数是多少

对于分成 k k k个连通块的每一个方案,有 m m m种颜色,那么方案数自然为: A m k A_m^k Amk

所以之需要求出,将树划分为 1.... k 1....k 1....k个连通块的方案数,对于每一个 k k k, a n s = a n s + c a l ( k ) ∗ A m k ans = ans + cal(k)*A_m^k ans=ans+cal(k)Amk

这个题就解决了

这里如何求出把树划分为k个连通块的方案数呢

两种方法:

1.树形dp

d p [ u ] [ k ] dp[u][k] dp[u][k]代表以u为根的子树,划分为 k k k个连通块的方案数

那么状态转移很显然,对于每一个 u u u的孩子 e e e

该子树有两种情况:

  • 融入上次的连通块
  • 不融入上次的连通块

所以就有:

		for(int i=1;i<=min(sz[u],m);i++){
    
    ///枚举子树大小
			for(int k=1;k<=min(sz[e],m);k++){
    
    
				if(i+k-1<=m) t[u][i+k-1] = (t[u][i+k-1] + (dp[u][i] * dp[e][k]) )%mod; 
			}
		}
		for(int i=1;i<=min(sz[u],m);i++){
    
    ///枚举子树大小
			for(int k=1;k<=min(sz[e],m);k++){
    
    
				if(i+k<=m) t[u][i+k] = (t[u][i+k] + (dp[u][i] * dp[e][k]) )%mod; 
			}
		}
		for(int k=1;k<=m;k++) dp[u][k] = t[u][k],t[u][k] = 0;

这样复杂度是 O ( n ∗ m ) O(n*m) O(nm)

2.组合数学

考虑把一颗树划分为 k k k个连通块,无非就是切断一棵树的 ( k − 1 ) (k-1) (k1)条边

所以在 n − 1 n-1 n1条边种 k − 1 k-1 k1条边,那么方案数就是 C n − 1 k − 1 C_{n-1}^{k-1} Cn1k1

然后再与 A m k A_m^k Amk相乘就好了

最后附一下代码, d f s ( 1 , 1 ) dfs(1,1) dfs(1,1)之后 d p [ 1 ] [ i ] dp[1][i] dp[1][i]就代表划分为i个连通块的方案数

Code:

/*** keep hungry and calm CoolGuang!  ***/
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
#define dl(x) printf("%lld\n",x);
#define di(x) printf("%d\n",x);
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const ll INF= 1e17+7;
const ll maxn =2e5+700;
const ll mod= 1e9+7;
const ll up = 1e13;
const double eps = 1e-9;
template<typename T>inline void read(T &a){
    
    char c=getchar();T x=0,f=1;while(!isdigit(c)){
    
    if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){
    
    x=(x<<1)+(x<<3)+c-'0';c=getchar();}a=f*x;}
ll n,m,p;
vector<int>v[maxn];
ll dp[305][305];
ll t[305][305];///代替分组背包
ll sz[maxn];
ll s[2005][2005];
ll cal(ll x,ll y){
    
    
 	if(x<y) return 0;
    if(y == 0 || x == y) return s[x][y] = 1;
    if(y == 1) return s[x][y] = x%mod;
    if(~s[x][y]) return s[x][y];
    return s[x][y] = (cal(x-1,y) + cal(x-1,y-1))%mod;
}
void dfs(int u,int fa){
    
    
	sz[u] = 1;
	dp[u][1] = 1;
	for(int e:v[u]){
    
    
		if(e == fa) continue;
	
		///对于新来的任何一个子数 都有与当前合并 和 不与当前合并
		dfs(e,u);

		for(int i=1;i<=min(sz[u],m);i++){
    
    ///枚举子树大小
			for(int k=1;k<=min(sz[e],m);k++){
    
    
				if(i+k-1<=m) t[u][i+k-1] = (t[u][i+k-1] + (dp[u][i] * dp[e][k]) )%mod; 
			}
		}
		for(int i=1;i<=min(sz[u],m);i++){
    
    ///枚举子树大小
			for(int k=1;k<=min(sz[e],m);k++){
    
    
				if(i+k<=m) t[u][i+k] = (t[u][i+k] + (dp[u][i] * dp[e][k]) )%mod; 
			}
		}
		for(int k=1;k<=m;k++) dp[u][k] = t[u][k],t[u][k] = 0;
		sz[u] += sz[e];
	}
}
ll A[maxn];
int main(){
    
    
	memset(s,-1,sizeof(s));
	read(n);read(m);
	for(int i=1;i<=n-1;i++){
    
    
		int x,y;read(x);read(y);
		v[x].push_back(y);
		v[y].push_back(x);
	}
	
	A[0] = 1;
	for(int i=1;i<=m;i++) A[i] = (A[i-1] * (m-i+1))%mod;
	ll ans = 0;
	for(int i=1;i<=m;i++){
    
    
		ans = (ans + (cal(n-1,i-1)*A[i])%mod)%mod;
	}
	printf("%lld\n",ans);
	return 0;
}
/***
ababd
abd
***/

猜你喜欢

转载自blog.csdn.net/qq_43857314/article/details/112277083