[CF995F] Cowmpany Cowmpensation(树形dp,拉格朗日插值)

树形DP:
f [ u ] [ i ] f[u][i] f[u][i]表示给 u u u的子树分配工资, u u u点工资为 i i i的方案数
f [ u ] [ i ] = ∏ v ∈ s o n u ( ∑ j = 1 i f [ v ] [ j ] ) f[u][i]=\prod\limits_{v\in son_u}(\sum\limits_{j=1}^{i}f[v][j]) f[u][i]=vsonu(j=1if[v][j])
前缀和优化:
g [ u ] [ i ] = ∑ j = 1 i f [ u ] [ j ] g[u][i]=\sum\limits_{j=1}^{i}f[u][j] g[u][i]=j=1if[u][j]
f [ u ] [ i ] = ∏ v ∈ s o n u g [ v ] [ i ] f[u][i]=\prod\limits_{v\in son_u}g[v][i] f[u][i]=vsonug[v][i]
时间复杂度 O ( n d ) O(nd) O(nd)
考虑用拉格朗日插值优化 O ( n 2 ) O(n^2) O(n2)
g u ( x ) g_u(x) gu(x)为关于 x x x的函数,代入 x x x,即可得到 g [ u ] [ x ] g[u][x] g[u][x]
合理猜想 g u ( x ) g_u(x) gu(x)的次数为 s z u sz_u szu u u u的子树大小),可以用数学归纳法证明:

  • u u u 为叶子结点时, g u ( x ) = x g_u(x)=x gu(x)=x,成立。
  • u u u 非叶子结点时,考虑:
    g u ( x ) − g u ( x − 1 ) = ∏ v ∈ s o n u g v ( x ) g_u(x)-g_u(x-1)=\prod\limits_{v\in son_u}g_v(x) gu(x)gu(x1)=vsonugv(x)
    由于 v v v 满足猜想,即 g v ( x ) g_v(x) gv(x) 的次数为 s z v sz_v szv,则 g u ( x ) − g u ( x − 1 ) g_u(x)-g_u(x-1) gu(x)gu(x1) 的次数为 ∑ v ∈ s o n u s z v \sum\limits_{v\in son_u}sz_v vsonuszv,即 s z u − 1 sz_u-1 szu1
    再还原差分,次数 +1,有 g u ( x ) g_u(x) gu(x) 是关于 x x x s z u − 1 + 1 = s z u sz_u-1+1=sz_u szu1+1=szu 次函数。

所以我们只要知道 g 1 ( 1 ) , g 1 ( 2 ) , . . . , g 1 ( n ) g_1(1),g_1(2),...,g_1(n) g1(1),g1(2),...,g1(n),便可以用拉格朗日插值得出 g 1 ( d ) g_1(d) g1(d)

#include<iostream>
#include<cstdio>
using namespace std;
typedef long long ll;
const int mod=1e9+7;
const int N=3050;
struct Edge{
    
    int v,nxt;}edge[N];
int n,cnt,head[N],fa[N];
int d,f[N][N],sum[N][N],inv[N],ans;
int add(int a,int b){
    
    return a+b>=mod?a+b-mod:a+b;}
int mul(int a,int b){
    
    return 1ll*a*b%mod;}
void addedge(int u,int v){
    
    
    edge[++cnt].v=v;edge[cnt].nxt=head[u];head[u]=cnt;
}
void dfs(int u){
    
    
    for(int i=head[u];i;i=edge[i].nxt){
    
    
        int v=edge[i].v;
        dfs(v);
        for(int j=1;j<=min(n,d);j++)
            f[u][j]=mul(f[u][j],sum[v][j]);
    }
    for(int j=1;j<=min(n,d);j++) sum[u][j]=add(sum[u][j-1],f[u][j]);
}
int main(){
    
    
    scanf("%d%d",&n,&d);
    inv[1]=1;
    for(int i=2;i<=n;++i) inv[i]=mul(mod-mod/i,inv[mod%i]);
    for(int i=2;i<=n;i++){
    
    
        scanf("%d",&fa[i]);
        addedge(fa[i],i);
    }
    for(int i=1;i<=n;i++)
        for(int j=1;j<=min(n,d);j++) f[i][j]=1;
    dfs(1);
    if(d<=n) printf("%d\n",sum[1][d]);
    else{
    
    
        for(int i=1;i<=n;i++){
    
    
            int x=sum[1][i];
            for(int j=0;j<=n;++j) 
                if(i^j) x=1ll*x*(d-j+mod)%mod*(i>j?inv[i-j]:mod-inv[j-i])%mod;
            ans=add(ans,x);
        }
        printf("%d\n",ans);
    }
    return 0;
}

Guess you like

Origin blog.csdn.net/Emma2oo6/article/details/121400474