CTS2019 珍珠 容斥、生成函数、多项式求逆

没有传送门

题目大意:给出一个长度为\(n\)的序列\(a_i\),序列中每一个数可以取\(1\)\(D\)中的所有数。问共有多少个序列满足:设\(p_i\)表示第\(i\)个数在序列中出现的次数,\(\sum\limits_{i=1}^D \lfloor \frac{p_i}{2} \rfloor \geq m\)\(D \leq 10^5 , 0 \leq m \leq n \leq 10^9\)


在有生之年切掉laofu的多项式题,全场唯一一个写多项式求逆的,其他人都直接卷积

首先上面的条件等价于:\(\sum\limits_{i=1}^D [2 \not\mid p_i] \leq n - 2m\)。那么一种想法是求出强制其中\(n - 2m + 1\)个数字出现次数为奇数,其他的数出现次数随意。那么这样的方案数是\(\binom{D}{n - 2m + 1} [x^n](\frac{e^x - e^{-x}}{2})^{n - 2m + 1} e^{(D - (n - 2m + 1))x}\)。但是注意到当出现奇数次的\(p_i\)的个数\(> n - 2m + 1\)的时候在这种情况下会被算重,而且算重的个数还不同,这暗示着需要进行容斥。

先做几个特判:\(n < 2m\)时答案为\(0\)\(D < n - 2m + 1\)时答案为\(D^n\)

不妨设\(f_i\)表示强制其中\(i\)个数字出现次数为奇数,其他的数出现次数随意的方案数,那么\(f_i = \binom{D}{i} [x^n](\frac{e^x - e^{-x}}{2})^{i} e^{(D - i)x}\),经过化简可以得到\(f_i = i! \binom{D}{i} \frac{1}{2^i} \sum\limits_{j=0}^i \frac{(-1)^j (D - 2j)^n}{(i-j)!j!}\)。不难发现后面是一个卷积形式,使用\(NTT\)\(O(DlogD)\)的时间复杂度内可以求出所有的\(f_i\)

然后又设\(g_i\)表示恰好\(i\)个数字出现奇数次的方案数。然后就和HAOI2018 染色的多项式求逆做法相同。

最后答案就是\(\sum\limits_{i=0}^{n - 2m} g_i\)

因为多项式求逆的做法太不优秀了,所以\(D=100000\)在CTS考场要跑0.95s……

#include<bits/stdc++.h>
using namespace std;
const int MOD = 998244353 , _ = (1 << 19) + 3 , INV2 = (MOD + 1) >> 1;

int poww(long long a , int b){
    a = (a % MOD + MOD) % MOD;
    int times = 1;
    while(b){
        if(b & 1) times = times * a % MOD;
        a = a * a % MOD;
        b >>= 1;
    }
    return times;
}

int jc[100007] , inv[100007] , D , N , M;

namespace poly{
    const int g = 3 , INV = 332748118;
    int dir[_] , need , invnd;
    int A[_] , B[_] , C[_];
    
    void init(int len){
        need = 1;
        while(need < len) need <<= 1;
        for(int i = 1 ; i < need ; ++i)
            dir[i] = (dir[i >> 1] >> 1) | (i & 1 ? need >> 1 : 0);
        invnd = poww(need , MOD - 2);
    }
    
    void NTT(int *arr , int tp){
        for(int i = 1 ; i < need ; ++i)
            if(i < dir[i]) arr[i] ^= arr[dir[i]] ^= arr[i] ^= arr[dir[i]];
        for(int i = 1 ; i < need ; i <<= 1){
            int wn = poww(tp == 1 ? g : INV , (MOD - 1) / i / 2);
            for(int j = 0 ; j < need ; j += i << 1){
                long long w = 1;
                for(int k = 0 ; k < i ; ++k , w = w * wn % MOD){
                    int x = arr[j + k] , y = arr[i + j + k] * w % MOD;
                    arr[j + k] = x + y >= MOD ? x + y - MOD : x + y;
                    arr[i + j + k] = x < y ? x + MOD - y : x - y;
                }
            }
        }
    }

#define clr(x) memset(x , 0 , sizeof(int) * (need))
    void getInv(int *a , int *b , int len){
        if(len == 1){
            b[0] = poww(a[0] , MOD - 2);
            return;
        }
        getInv(a , b , (len + 1) >> 1);
        init(len * 2 + 5);
        memcpy(A , a , sizeof(int) * (len + 1));
        memcpy(B , b , sizeof(int) * (len + 1));
        NTT(A , 1); NTT(B , 1);
        for(int i = 0 ; i < need ; ++i)
            A[i] = 1ll * A[i] * B[i] % MOD * B[i] % MOD;
        NTT(A , 0);
        for(int i = 0 ; i < len ; ++i)
            b[i] = (2ll * b[i] - 1ll * A[i] * invnd % MOD + MOD) % MOD;
        clr(A); clr(B);
    }
}
using namespace poly;

void init(){
    jc[0] = 1;
    for(int i = 1 ; i <= D ; ++i)
        jc[i] = 1ll * jc[i - 1] * i % MOD;
    inv[D] = poww(jc[D] , MOD - 2);
    for(int i = D - 1 ; i >= 0 ; --i)
        inv[i] = inv[i + 1] * (i + 1ll) % MOD;
    
}

int binom(int x , int y){
    return x < y ? 0 : 1ll * jc[x] * inv[y] % MOD * inv[x - y] % MOD;
}

int F[_] , G[_] , H[_];

signed main(){
    freopen("pearl.in","r",stdin);
    freopen("pearl.out","w",stdout);
    cin >> D >> N >> M;
    if(N < 2 * M) cout << 0;
    else
        if(D - N + 2 * M - 1 < 0 || M == 0)
            cout << poww(D , N);
        else{
            init();
            for(int i = 0 ; i <= D ; ++i){
                F[i] = ((i & 1 ? -1ll : 1ll) * poww((D - 2 * i + MOD) % MOD , N) * inv[i] % MOD + MOD) % MOD;
                G[i] = inv[i];
            }
            init(2 * D + 2); NTT(F , 1); NTT(G , 1);
            for(int i = 0 ; i < need ; ++i)
                F[i] = 1ll * F[i] * G[i] % MOD;
            NTT(F , 0);
            for(int i = 0 ; i < need ; ++i)
                if(i <= D)
                    F[i] = 1ll * F[i] * invnd % MOD * poww(INV2 , i) % MOD * jc[i] % MOD * binom(D , i) % MOD;
                else F[i] = 0;
            clr(G);
            for(int i = 0 ; i <= D ; ++i)
                F[i] = 1ll * F[i] * jc[i] % MOD;
            reverse(F , F + D + 1);
            for(int i = 0 ; i <= D ; ++i)
                H[i] = inv[i];
            getInv(H , G , D + 1);
            init(2 * D + 2); NTT(F , 1); NTT(G , 1);
            for(int i = 0 ; i < need ; ++i)
                F[i] = 1ll * F[i] * G[i] % MOD;
            NTT(F , 0);
            reverse(F , F + D + 1);
            for(int i = 0 ; i <= D ; ++i)
                F[i] = 1ll * F[i] * invnd % MOD * inv[i] % MOD;
            int sum = 0;
            for(int i = 0 ; i <= D ; ++i)
                if(i >= N - 2 * M + 1)
                    sum = (sum + F[i]) % MOD;
            cout << (poww(D , N) - sum + MOD) % MOD;
        }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Itst/p/10858580.html
今日推荐