@loj - 2320@ 「清华集训 2017」生成树计数


@description@

在一个 s 个点的图中,存在 s - n 条边,使图中形成了 n 个连通块,第 i 个连通块中有 \(a_i\) 个点。

现在我们需要再连接 n - 1 条边,使该图变成一棵树。对一种连边方案,设原图中第 i 个连通块连出了 \(d_i\) 条边,那么这棵树 T 的价值为:

\[val(T) = (\prod_{i=1}^{n}d_{i}^{m})(\sum_{i=1}^{n}d_{i}^{m})\]

你的任务是求出所有可能的生成树的价值之和,对 998244353 取模。

原题戳我

@solution@

@正文@

注意到 \(d_i\) 为度数,那么考虑 prufer 序列,直接写出答案表达式:

\[ans = \sum_{(\sum_{i=1}^{n}b_i)=n-2}(\frac{(n-2)!}{\prod_{i=1}^{n}b_i!})\times(\prod_{i=1}^{n}a_{i}^{b_i + 1})\times(\prod_{i=1}^{n}(b_{i} + 1)^{m})\times(\sum_{i=1}^{n}(b_{i} + 1)^{m})\]

其中 \(b_i + 1 = d_i\)

作一些简单的变形:
\[ans = (n-2)!\times(\prod_{i=1}^{n}a_i)\times\sum_{i=1}^{n}\sum_{(\sum_{j=1}^{n}b_j)=n-2}(\frac{(b_{i} + 1)^{2m}\times a_{i}^{b_{i}}}{b_{i}!})\times(\prod_{j=1,j\not =i}^{n}\frac{(b_{j} + 1)^{2m}\times a_{j}^{b_{j}}}{b_{j}!})\]

引入生成函数。如果记 \(P(x) = \sum_{i=0}\frac{(i + 1)^{2m}\times x^i}{i!}\)\(Q(x) = \sum_{i=0}\frac{(i + 1)^{m}\times x^i}{i!}\),则:
\[ans = (n-2)!\times(\prod_{i=1}^{n}a_i)\times([x^{n-2}]\sum_{i=1}^{n}P(a_i x)\times(\prod_{j=1,j\not =i}^{n}Q(a_j x)))\\ ans = (n-2)!\times(\prod_{i=1}^{n}a_i)\times([x^{n-2}]\prod_{i=1}^{n}Q(a_i x)\times\sum_{i=1}^{n}\frac{P(a_i x)}{Q(a_i x)})\]

注意到 \(\frac{P(a_i x)}{Q(a_i x)}\) 其实就是 \(\frac{P(x)}{Q(x)}\) 的第 k 项乘上 \(a_i^{k}\)

也就是说 \(\sum_{i=1}^{n}\frac{P(a_i x)}{Q(a_i x)}\) 就是 \(\frac{P(x)}{Q(x)}\) 的第 k 项乘上 \(\sum_{i=1}^{n}a_i^{k}\),而 \(\sum_{i=1}^{n}a_i^{k}\) 是可以快速求出的(在补充部分介绍)。

尝试把 \(\prod_{i=1}^{n}Q(a_i x)\) 也化成加法形式:利用对数,可以得到 \(\prod_{i=1}^{n}Q(a_i x) = \exp(\ln(\sum_{i=1}^{n}Q(a_i x)))\)

之后就没有了。只要求出了 \(\sum_{i=1}^{n}a_i^{k}\),剩下的都是模板。

@补充@

关于如何求 \(\sum_{i=1}^{n}a_i^{k}\),其实方法比较多,这里介绍一种:

注意到 \(\ln(1 - x) = -\sum_{i=1}\frac{x^i}{i}\),那么只要求出 \(\sum_{i=1}^{n}\ln(1 - a_ix)\),也就求出了 \(\sum_{i=1}^{n}a_i^{k}\)

利用对数的性质,有 \(\sum_{i=1}^{n}\ln(1 - a_ix) = \ln(\prod_{i=1}^{n}(1 - a_ix))\)

然后里面那个式子分治 fft 可以 O(nlog^2n) 搞定,这样一来总时间复杂度其实就是 O(nlog^2n)。

@accepted code@

#include <cstdio>
#include <algorithm>
using namespace std;

const int MAXN = 4*30000;
const int MOD = 998244353;

struct mint{
    int x;
    mint(int _x=0) : x(_x) {}
    friend mint operator + (mint a, mint b) {
        return a.x + b.x >= MOD ? a.x + b.x - MOD : a.x + b.x;
    }
    friend mint operator - (mint a, mint b) {
        return a.x - b.x < 0 ? a.x - b.x + MOD : a.x - b.x;
    }
    friend mint operator * (mint a, mint b) {
        return (int)(1LL * a.x * b.x % MOD);
    }
    friend mint pow_mod(mint b, int p) {
        mint ret = 1;
        while( p ) {
            if( p & 1 ) ret *= b;
            b *= b;
            p >>= 1;
        }
        return ret;
    }
    friend mint operator / (mint a, mint b) {
        return a * pow_mod(b, MOD - 2);
    }
    friend void operator += (mint &a, mint b) {a = a + b;}
    friend void operator -= (mint &a, mint b) {a = a - b;}
    friend void operator *= (mint &a, mint b) {a = a * b;}
    friend void operator /= (mint &a, mint b) {a = a / b;}
};

namespace poly{
    const mint G = 3;   
    mint w[20], iw[20], inv[MAXN + 5];
    void init() {
        for(int i=0;i<20;i++) {
            w[i] = pow_mod(G, (MOD-1)/(1<<i));
            iw[i] = pow_mod(w[i], MOD-2);
        }
        inv[1] = 1;
        for(int i=2;i<=MAXN;i++)
            inv[i] = 0 - (MOD/i)*inv[MOD%i];
    }
    void debug(mint *A, int n) {
        for(int i=0;i<n;i++)
            printf("%d ", A[i].x);
        puts("");
    }
    void ntt(mint *A, int n, int type) {
        for(int i=0,j=0;i<n;i++) {
            if( i < j ) swap(A[i], A[j]);
            for(int k=(n>>1);(j^=k)<k;k>>=1);
        }
        for(int i=1;(1<<i)<=n;i++) {
            int s = (1 << i), t = (s >> 1);
            mint u = (type == 1 ? w[i] : iw[i]);
            for(int j=0;j<n;j+=s) {
                mint p = 1;
                for(int k=0;k<t;k++,p*=u) {
                    mint x = A[j+k], y = A[j+k+t];
                    A[j+k] = x + y*p, A[j+k+t] = x - y*p;
                }
            }
        }
        if( type == -1 ) {
            mint iv = inv[n];
            for(int i=0;i<n;i++)
                A[i] *= iv;
        }
    }
    int length(int n) {
        int len; for(len = 1; len < n; len <<= 1);
        return len;
    }
    void pcopy(mint *A, mint *B, int n, int l) {
        for(int i=0;i<n;i++) A[i] = B[i];
        for(int i=n;i<l;i++) A[i] = 0;
    }
    mint t1[MAXN + 5], t2[MAXN + 5];
    void pmul(mint *A, int nA, mint *B, int nB, mint *C) {
        int len = length(nA + nB - 1);
        pcopy(t1, A, nA, len), ntt(t1, len, 1);
        pcopy(t2, B, nB, len), ntt(t2, len, 1);
        for(int i=0;i<len;i++) C[i] = t1[i] * t2[i];
        ntt(C, len, -1);
    }
    mint t3[MAXN + 5], t4[MAXN + 5];
    void pinv(mint *A, mint *B, int n) {
        if( n == 1 ) {
            B[0] = 1 / A[0];
            return ;
        }
        int m = (n + 1) >> 1; pinv(A, B, m);
        int len = length(n << 1);
        pcopy(t3, A, n, len), ntt(t3, len, 1);
        pcopy(t4, B, m, len), ntt(t4, len, 1);
        for(int i=0;i<len;i++) B[i] = t4[i]*(2 - t3[i]*t4[i]);
        ntt(B, len, -1);
    }
    void pdif(mint *A, mint *B, int n) {
        for(int i=1;i<n;i++)
            B[i-1] = A[i] * i;
    }
    void pint(mint *A, mint *B, int n) {
        for(int i=n-1;i>=0;i--)
            B[i+1] = A[i] / (i + 1);
        B[0] = 0;
    }
    mint t5[MAXN + 5], t6[MAXN + 5];
    void pln(mint *A, mint *B, int n) {
        pinv(A, t5, n), pdif(A, t6, n);
        pmul(t5, n, t6, n, B);
        pint(B, B, n);
    }
    mint t7[MAXN + 5], t8[MAXN + 5];
    void pexp(mint *A, mint *B, int n) {
        if( n == 1 ) {
            B[0] = 1;
            return ;
        }
        int m = (n + 1) >> 1; pexp(A, B, m);
        int len = length(n << 1);
        pcopy(t7, B, m, len), pln(t7, t8, n), pcopy(t7, t8, n, len);
        pcopy(t8, B, m, len);
        for(int i=0;i<n;i++) t7[i] = A[i] - t7[i];
        t7[0] = t7[0] + 1;
        ntt(t7, len, 1), ntt(t8, len, 1);
        for(int i=0;i<len;i++) B[i] = t7[i] * t8[i];
        ntt(B, len, -1);
    }
}

int n, m, k;

mint A[MAXN + 5], B[MAXN + 5];
void init() {
    mint t = 1;
    for(int i=0;i<n;i++,t*=i) {
        mint a = 1 / t, b = pow_mod(mint(i + 1), m);
        A[i] = a * b * b, B[i] = a * b;
    }
    poly::init();
}

mint a[MAXN + 5], f[MAXN + 5], s[MAXN + 5];
int solve(int l, int r) {
    if( l == r ) {
        f[l<<1] = 1, f[l<<1|1] = 0 - a[l];
        return 2;
    }
    int mid = (l + r) >> 1;
    int ls = solve(l, mid), rs = solve(mid + 1, r);
    poly::pmul(f + (l<<1), ls, f + ((mid + 1) << 1), rs, f + (l << 1));
    return ls + rs - 1;
}
void get_pow_sum() {
    solve(0, n - 1), poly::pln(f, s, n + 1);
    s[0] = n;
    for(int i=1;i<=n;i++)
        s[i] = 0 - s[i]*i;
}

mint t1[MAXN + 5], t2[MAXN + 5];
int main() {
    scanf("%d%d", &n, &m), k = n - 2, init();
    for(int i=0;i<n;i++) scanf("%d", &a[i].x);
    
    get_pow_sum();
    poly::pln(B, t1, n);
    for(int i=0;i<n;i++)
        t1[i] *= s[i];
    poly::pexp(t1, t2, n);
    poly::pinv(B, t1, n);
    poly::pmul(A, n, t1, n, t1);
    for(int i=0;i<n;i++)
        t1[i] *= s[i];
    poly::pmul(t1, n, t2, n, t1);
    mint ans = t1[n - 2];
    for(int i=0;i<n;i++) ans *= a[i];
    for(int i=1;i<=n-2;i++) ans *= i;
    printf("%d\n", ans.x);
}

@details@

顺带一提,这道题还有依赖于斯特林数的 O(nmlogn) 的做法(但是我看不懂 QaQ)。

猜你喜欢

转载自www.cnblogs.com/Tiw-Air-OAO/p/12119237.html
今日推荐