【AGC005F】Many Easy Problems (NTT)

Description

​ 给你一棵\(~n~\)个点的树和一个整数。设为\(~S~\)为树上某些点的集合,定义\(~f(S)~\)为最小的包含\(~S~\)的联通子图的大小。\(~n~\)个点选\(~k~\)个点一共有\(~C_n^k~\)种方案,请你求出所有方案的\(~f(S)~\))的和, 对\(~924844033~\)取模。

​ 求所有\(~k \in [1, ~n]~\)的答案。
看题戳我

Solution

​ 首先看到这道题,根本不会快速求\(~f(S)~\),所以换一个角度,考虑每个点对于答案的贡献。不难发现, 对于单独一个\(~k~\),一个点\(~u~\)会产生贡献当且仅当这\(~k~\)个点不全在以\(~u~\)的相邻节点为根的子树中,根据容斥可以得到一个点对一个\(~k~\)的贡献为\(~C_n^k - \sum_{v \in {nex_u}} ^{} {C_{siz_v}^k}~\),观察这个式子,可以发现计算总贡献时每个点的子树大小会被计算两次,一个是本身子树大小\(~siz_u~\),一个是\(~n - siz_u~\),而 \(~C_n^k~\)被计算了\(~n~\)次,所以有
\[ Ans_k = n \times {n \choose k} - \sum_{i = 1} ^ n {num_i \times {i \choose k}} \]
​ 其中,\(num_i~\)表示子树大小为\(~i~\)的子树个数,这样已经可以卷后半部分了,但是我们想要一个更简便的式子。

定义一个新的\(~cnt_i~\)表示
\[ cnt_i = \begin{cases} n, ~ i = n\\ -num_i, ~ i \neq n \end{cases} \]
所以上面的式子可以更简便的表示为
\[ Ans_k = {\sum_{i = 1}^{n} cnt_i \times {i \choose k}} = \frac{1}{k!}~{\sum_{i = 1}^{n}} ~\frac{cnt_i \times i!}{(i - k)!} \]
​ 那么把\(~cnt_i \times i!~\)放一起,\(~\frac{1}{(i - k)!}~\)放一起, 用一个\(~FFT~\)套路把\(~(i - k)~\)倒过来之后就可以卷起来了。

​ 但是为了求了这个更简便的式子会导致\(~cnt_i \times i!~\)可能是负数,而我的\(~NTT~\)已经习惯了这样写,look down,因为普通题目中要卷起来的一般都是正的,所以一开始就把这题要卷的东西变成正的也是可以的。

a[j + k] = (x + y) % mod, a[j + k + (i >> 1)] = (x - y + mod) % mod; 

​ 而我一开始没有转正,所以这样写很不优秀,因为一旦\(~y~\)是一个比较小的负数,那么\(~(x - y + mod)~\)就爆\(~int~\)了,我因为这里调了一个晚上+一个下午,很难受。

最后提一下这个题的模数是\(~924844033~\),所以原根是\(~5~\)而不是熟知的\(~3~\)

Code

#include<bits/stdc++.h>
#define For(i, j, k) for(int i = j; i <= k; ++i)
#define Forr(i, j, k) for(int i = j; i >= k; --i)
#define Travel(i, u) for(int i = beg[u], v = to[i]; i; i = nex[i], v = to[i])
using namespace std;

inline int read() {
    int x = 0, p = 1; char c = getchar();
    for(; !isdigit(c); c = getchar()) if(c == '-') p = -1;
    for(; isdigit(c); c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
    return x *= p;
}

inline void File() {
#ifndef ONLINE_JUDGE
    freopen("AGC005F.in", "r", stdin);
    freopen("AGC005F.out", "w", stdout);
#endif
}

const int N = 2e5 + 10, maxn = N << 2, mod = 924844033;
int a[maxn], b[maxn], e = 1, beg[N], nex[N << 1], to[N << 1]; 
int rev[maxn], bit, len, siz, invg[maxn], powg[maxn]; 
int fac[N], inv[N], cnt[N], sz[N], u, v, n;

inline int qpow(int a, int b) {
    int res = 1;
    for (; b; a = 1ll * a * a % mod, b >>= 1) 
        if (b & 1) res = 1ll * res * a % mod;
    return res;
}

inline void Init(int n) {
    fac[0] = inv[0] = 1;
    For(i, 1, n) fac[i] = 1ll * i * fac[i - 1] % mod;
    inv[n] = qpow(fac[n], mod - 2);
    Forr(i, n - 1, 0) inv[i] = 1ll * inv[i + 1] * (i + 1) % mod;
}

inline void add(int x, int y) {
    to[++ e] = y, nex[e] = beg[x], beg[x] = e;
    to[++ e] = x, nex[e] = beg[y], beg[y] = e;
}

inline void dfs(int u, int f) {
    sz[u] = 1;
    Travel(i, u) if (v != f) dfs(v, u), sz[u] += sz[v];
    -- cnt[sz[u]], -- cnt[n - sz[u]];
}

inline void NTT(int *a, int flag) {
    For(i, 0, siz - 1) if (rev[i] > i) swap(a[rev[i]], a[i]);
    for (int i = 2; i <= siz; i <<= 1) {
        int wn = flag ? powg[i] : invg[i];
        for (int j = 0; j < siz; j += i) {
            int w = 1;
            for (int k = 0; k < (i >> 1); ++ k, w = 1ll * w * wn % mod) {
                int x = a[j + k], y = 1ll * w * a[j + k + (i >> 1)] % mod;
                a[j + k] = (x + y) % mod, a[j + k + (i >> 1)] = (x - y) % mod; 
            }
        }
    }
    if (!flag) {
        int g = qpow(siz, mod - 2);
        For(i, 0, siz - 1) a[i] = 1ll * g * a[i] % mod;
    }
}

int main() {
    File(), Init(N - 5);
    n = read();
    For(i, 2, n) u = read(), v = read(), add(u, v);
    dfs(1, 0), cnt[n] = n;

    for (siz = 1; siz <= (n << 1); siz <<= 1) ++ bit;
    For(i, 0, siz - 1) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));

    int g = qpow(5, mod - 2);
    for(int i = 1; i <= siz; i <<= 1) {
        invg[i] = qpow(g, (mod - 1) / i),
        powg[i] = qpow(5, (mod - 1) / i);
    }

    For(i, 0, n) {
        a[i] = 1ll * cnt[i] * fac[i] % mod;
        b[i] = inv[n - i];
    }

    NTT(a, 1), NTT(b, 1);
    For(i, 0, siz) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, 0);

    For(i, 1, n) {
        int ans = 1ll * a[n + i] * inv[i] % mod;
        ans = (ans + mod) % mod;
        printf("%d\n", ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/LSTete/p/9506171.html
NTT