HDU 4625 JZPTREE 数学 - 斯特林数学习笔记

题目大意:给你一颗树,对每一个点x求所有点到其距离的k次方之和。 n 50000 ,   k 500

斯特林数的一个应用,先考虑 O ( n k 2 ) 暴力怎么做,例如,求出x到其子树中所有点的距离k次之和(然后再转移出到所有点的答案即可,过程类似):

d p k [ x ] = y T r e e x d i s ( x , y ) k

= y s o n x z T r e e y ( d i s ( y , z ) + 1 ) k

= y s o n x z T r e e y i = 0 k ( k i ) d i s ( y , z ) i

= y s o n x i = 0 k ( k i ) z T r e e y d i s ( y , z ) i

= y s o n x i = 0 k ( k i ) d p i [ y ]

按照上述结论转移即可做到 O ( n k 2 )
考虑优化,注意到:
n k = i = 1 k ( n i ) × S ( k , i ) × i !

n k = i = 1 k S ( k , i ) × [ n ] i

(实际上i到底枚举到n还是k是不会影响答案的,但是这里k比较小,多枚举没有意义,因此只枚举到k)。其中S(n,k)为第二类斯特林数,表示把n个有标号球放到k个相同的非空盒子里的方案数,显然有:
S ( n , k ) = S ( n 1 , k 1 ) + S ( n 1 , k ) × k

其意义是,要么第n个求单独一个盒子,要么和之前某个非空的盒子一组。
上一个式子的意义是,把k个带标号球放到n个带标号盒子里,显然有 n k 种方法,或者可以这么计数:先决定放到哪 1 i k 个盒子里,然后用斯特林数放进去,但这样是不带标号盒子,所以要乘上一个阶乘。最后那个[n]_i叫做n的i次下降幂,其实就是 P ( n , i ) = n ( n 1 ) ( n 2 ) . . . ( n i + 1 ) 。这告诉我们,对于k次幂计数,可以转化为对k次下降幂(其实就是组合数计数,二者只差一个阶乘)计数,然后用O(k)的时间转化回k次幂计数(事实上,用第一类斯特林数,也可以将下降幂计数转化成幂计数,再用O(k)的时间转化回去)。而组合数计数有一个好处:
( n k ) = ( n 1 k ) + ( n 1 k 1 )

把之前k次幂转化为组合数乘以斯特林数再乘上阶乘的式子代入,我们得到(略去化简):
d p k [ x ] = i = 1 k S ( k , i ) i ! y T r e e x ( d i s ( x , y ) i )

所以为题转化为,对后面一个sigma计数。
令(这里字母和代码中字母不同):
f k [ x ] = y T r e e x ( d i s ( x , y ) k ) = y s o n x z T r e e y ( d i s ( y , z ) + 1 k )

注意这里的转化没有考虑dis(x,x),也就是当k=0的时候特殊处理一下。
应用上文提到的关于组合数的帕斯卡等式,有:
f k [ x ] = y s o n x z T r e e y [ ( d i s ( y , z ) k ) + ( d i s ( y , z ) k 1 ) ]

f k [ x ] = y s o n x ( f k [ y ] + f k 1 [ y ] )

因此这种转化的好处是,将递推的时间复杂度优化掉一个k,尽管还要再转化回去的复杂度会多一个k,但是对于 O ( n k 2 ) O ( n ) 的算法,改成 O ( n k ) O ( n k ) 显然更优。
代码:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define lint long long
#define mod 10007
#define N 50010
#define K 510
using namespace std;
struct edges{
    int to,pre;
}e[N<<1];int fac[K],s[K][K],dp[N][K],ans[N][K],etop,h[N];
inline int add_edge(int u,int v)
{   return e[++etop].to=v,e[etop].pre=h[u],h[u]=etop;   }
inline int prelude(int n)
{
    for(int i=fac[0]=1;i<=n;i++) fac[i]=(lint)fac[i-1]*i%mod;
    s[0][0]=1;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=i;j++)
            s[i][j]=(s[i-1][j-1]+j*s[i-1][j])%mod;
    return 0;
}
int get_dp(int x,int f,int k)
{
    memset(dp[x],0,sizeof(int)*(k+1)),dp[x][0]=1;
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^f)
        {
            get_dp(y,x,k),dp[x][0]+=dp[y][0];
            for(int j=1;j<=k;j++) dp[x][j]+=dp[y][j]+dp[y][j-1];
        }
    for(int i=1;i<=k;i++) (dp[x][i]>=mod?dp[x][i]%=mod:0);return 0;
}
inline int f(int x,int y,int i)
{   return ans[x][i]-dp[y][i]-(i>0?dp[y][i-1]:0)+mod*2; }
int get_ans(int x,int fa,int k)
{
    for(int i=0;i<=k;i++) (ans[x][i]+=dp[x][i])%=mod;
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^fa)
        {
            ans[y][0]=ans[x][0]-dp[y][0];
            for(int j=1;j<=k;j++)
                ans[y][j]=f(x,y,j)+f(x,y,j-1);
            get_ans(y,x,k);
        }
    return 0;
}
inline int calc(int *ans,int k)
{
    lint res=0ll;
    for(int i=1;i<=k;i++) res+=(lint)fac[i]*s[k][i]*ans[i];
    return res%mod;
}
int main()
{
    int T;scanf("%d",&T),prelude(500);
    while(T--)
    {
        int n,k;scanf("%d%d",&n,&k);
        memset(h,0,sizeof(int)*(n+1)),etop=0;
        for(int i=1;i<n;i++)
        {
            int u,v;scanf("%d%d",&u,&v);
            add_edge(u,v),add_edge(v,u);
        }
        int rt=1;memset(ans[rt],0,sizeof(int)*(k+1));
        get_dp(rt,0,k),get_ans(rt,0,k);
        for(int i=1;i<=n;i++) printf("%d\n",calc(ans[i],k));
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/mys_c_k/article/details/79942486