题目大意:
shy有一颗树,树有n个结点。有k种不同颜色的染料给树染色。一个染色方案是合法的,当且仅当对于所有相同颜色的点对(x,y),x到y的路径上的所有点的颜色都要与x和y相同。请统计方案数。
题目思路:
把题目要求转换一下,对于每一种颜色均为一个连通块。
否则,则不满足要求
所以可以考虑,把这个树分成 k k k个连通块的方案数是多少
对于分成 k k k个连通块的每一个方案,有 m m m种颜色,那么方案数自然为: A m k A_m^k Amk
所以之需要求出,将树划分为 1.... k 1....k 1....k个连通块的方案数,对于每一个 k k k, a n s = a n s + c a l ( k ) ∗ A m k ans = ans + cal(k)*A_m^k ans=ans+cal(k)∗Amk
这个题就解决了
这里如何求出把树划分为k个连通块的方案数呢
两种方法:
1.树形dp
d p [ u ] [ k ] dp[u][k] dp[u][k]代表以u为根的子树,划分为 k k k个连通块的方案数
那么状态转移很显然,对于每一个 u u u的孩子 e e e
该子树有两种情况:
- 融入上次的连通块
- 不融入上次的连通块
所以就有:
for(int i=1;i<=min(sz[u],m);i++){
///枚举子树大小
for(int k=1;k<=min(sz[e],m);k++){
if(i+k-1<=m) t[u][i+k-1] = (t[u][i+k-1] + (dp[u][i] * dp[e][k]) )%mod;
}
}
for(int i=1;i<=min(sz[u],m);i++){
///枚举子树大小
for(int k=1;k<=min(sz[e],m);k++){
if(i+k<=m) t[u][i+k] = (t[u][i+k] + (dp[u][i] * dp[e][k]) )%mod;
}
}
for(int k=1;k<=m;k++) dp[u][k] = t[u][k],t[u][k] = 0;
这样复杂度是 O ( n ∗ m ) O(n*m) O(n∗m)的
2.组合数学
考虑把一颗树划分为 k k k个连通块,无非就是切断一棵树的 ( k − 1 ) (k-1) (k−1)条边
所以在 n − 1 n-1 n−1条边种 k − 1 k-1 k−1条边,那么方案数就是 C n − 1 k − 1 C_{n-1}^{k-1} Cn−1k−1
然后再与 A m k A_m^k Amk相乘就好了
最后附一下代码, d f s ( 1 , 1 ) dfs(1,1) dfs(1,1)之后 d p [ 1 ] [ i ] dp[1][i] dp[1][i]就代表划分为i个连通块的方案数
Code:
/*** keep hungry and calm CoolGuang! ***/
#pragma GCC optimize("Ofast","unroll-loops","omit-frame-pointer","inline")
#pragma GCC optimize(3)
#include <bits/stdc++.h>
#include<stdio.h>
#include<queue>
#include<algorithm>
#include<string.h>
#include<iostream>
#define debug(x) cout<<#x<<":"<<x<<endl;
#define dl(x) printf("%lld\n",x);
#define di(x) printf("%d\n",x);
typedef long long ll;
typedef unsigned long long ull;
using namespace std;
const ll INF= 1e17+7;
const ll maxn =2e5+700;
const ll mod= 1e9+7;
const ll up = 1e13;
const double eps = 1e-9;
template<typename T>inline void read(T &a){
char c=getchar();T x=0,f=1;while(!isdigit(c)){
if(c=='-')f=-1;c=getchar();}
while(isdigit(c)){
x=(x<<1)+(x<<3)+c-'0';c=getchar();}a=f*x;}
ll n,m,p;
vector<int>v[maxn];
ll dp[305][305];
ll t[305][305];///代替分组背包
ll sz[maxn];
ll s[2005][2005];
ll cal(ll x,ll y){
if(x<y) return 0;
if(y == 0 || x == y) return s[x][y] = 1;
if(y == 1) return s[x][y] = x%mod;
if(~s[x][y]) return s[x][y];
return s[x][y] = (cal(x-1,y) + cal(x-1,y-1))%mod;
}
void dfs(int u,int fa){
sz[u] = 1;
dp[u][1] = 1;
for(int e:v[u]){
if(e == fa) continue;
///对于新来的任何一个子数 都有与当前合并 和 不与当前合并
dfs(e,u);
for(int i=1;i<=min(sz[u],m);i++){
///枚举子树大小
for(int k=1;k<=min(sz[e],m);k++){
if(i+k-1<=m) t[u][i+k-1] = (t[u][i+k-1] + (dp[u][i] * dp[e][k]) )%mod;
}
}
for(int i=1;i<=min(sz[u],m);i++){
///枚举子树大小
for(int k=1;k<=min(sz[e],m);k++){
if(i+k<=m) t[u][i+k] = (t[u][i+k] + (dp[u][i] * dp[e][k]) )%mod;
}
}
for(int k=1;k<=m;k++) dp[u][k] = t[u][k],t[u][k] = 0;
sz[u] += sz[e];
}
}
ll A[maxn];
int main(){
memset(s,-1,sizeof(s));
read(n);read(m);
for(int i=1;i<=n-1;i++){
int x,y;read(x);read(y);
v[x].push_back(y);
v[y].push_back(x);
}
A[0] = 1;
for(int i=1;i<=m;i++) A[i] = (A[i-1] * (m-i+1))%mod;
ll ans = 0;
for(int i=1;i<=m;i++){
ans = (ans + (cal(n-1,i-1)*A[i])%mod)%mod;
}
printf("%lld\n",ans);
return 0;
}
/***
ababd
abd
***/