[XSY 2666][分治FFT][容斥]排列问题

在这里插入图片描述
在这里插入图片描述
yy一个方案有多少对相邻的相同,显然是m-颜色段数
我们令 f i , j f_{i,j} 表示前i种颜色分成j段颜色段数的方案数。
注意,我们这里允许两个同样颜色的颜色段相邻。
可以得到转移方程 f i , j = k = 0 j 1 f i 1 , k ( a i 1 j k 1 ) 1 j k f_{i,j}=\sum_{k=0}^{j-1}f_{i-1,k}*\binom{a_{i}-1}{j-k-1}*\frac{1}{j-k}
可以分治FFT完成。
当然,这求出来的不是真正的f。
我们需要执行 f n , i = i ! f n , i ( i ϵ [ 1 , m ] ) f_{n,i}=i!*f_{n,i}(i\epsilon [1,m])
不妨令 g i = f n , i g_{i}=f_{n,i} , a n s i ans_{i} 为分成i段的真正答案(即不允许有同样颜色的颜色段相邻)
可以得到容斥式子:
a n s i = g i j = 1 i 1 a n s j ( m j i j ) ans_{i}=g_{i}-\sum_{j=1}^{i-1}ans_{j}*\binom{m-j}{i-j}
考虑继续化简。
可以发现, g j g_{j} a n s i ans_{i} 的贡献为 ( 1 ) i j ( m j i j ) (-1)^{i-j}*\binom{m-j}{i-j}
证明如下(采用归纳证明):
g k > a n s i g_{k}->ans_{i}
= j = k i 1 ( 1 ) j k + 1 ( m k j k ) ( m j i j ) =\sum_{j=k}^{i-1}(-1)^{j-k+1}*\binom{m-k}{j-k}*\binom{m-j}{i-j}
= j = k i 1 ( 1 ) j k + 1 ( m k ) ! ( m j ) ! ( j k ) ! ( m j ) ! ( m i ) ! ( i j ) ! =\sum_{j=k}^{i-1}(-1)^{j-k+1}*\frac{(m-k)!}{(m-j)!(j-k)!}*\frac{(m-j)!}{(m-i)!(i-j)!}
= j = k i 1 ( 1 ) j k + 1 ( m k ) ! ( j k ) ! ( m i ) ! ( i j ) ! =\sum_{j=k}^{i-1}(-1)^{j-k+1}*\frac{(m-k)!}{(j-k)!(m-i)!(i-j)!}
= ( m k ) ! ( m i ) ! ( i k ) ! j = k i 1 ( 1 ) j k + 1 ( i k ) ! ( i j ) ! ( j k ) ! =\frac{(m-k)!}{(m-i)!(i-k)!}\sum_{j=k}^{i-1}(-1)^{j-k+1}*\frac{(i-k)!}{(i-j)!(j-k)!}
= ( m k i k ) j = k i 1 ( 1 ) j k + 1 ( i k j k ) =\binom{m-k}{i-k} \sum_{j=k}^{i-1}(-1)^{j-k+1}*\binom{i-k}{j-k}
= ( 1 ) i k ( m k i k ) =(-1)^{i-k}*\binom{m-k}{i-k}
来次NTT即可。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;
const int Mod=998244353;
int n,m,q;
#define Maxn 200010
int a[Maxn];
int fact[Maxn],inv[Maxn];
inline int C(int i,int j){return 1ll*fact[i]*inv[i-j]%Mod*inv[j]%Mod;}

int A[Maxn<<2],B[Maxn<<2];
int rev[Maxn<<2],len,bit;
inline int FP(int a,int b){
    int ans=1;
    while(b){
        if(b&1)ans=1ll*ans*a%Mod;
        a=1ll*a*a%Mod;
        b>>=1;
    }
    return ans;
}
inline void NTT(int *A,int t){
    for(int i=0;i<len;++i)
        if(i<rev[i])swap(A[i],A[rev[i]]);
    for(int i=1;i<len;i<<=1){
        int gn=FP(3,(t*(Mod-1)/(i<<1)+(Mod-1))%(Mod-1));
        for(int j=0;j<len;j+=i<<1){
            int g=1;
            for(int k=0;k<i;++k){
                int x=A[j+k];
                int y=1ll*g*A[j+k+i]%Mod;
                A[j+k]=(x+y)%Mod;
                A[j+k+i]=(x-y+Mod)%Mod;
                g=1ll*g*gn%Mod;
            }
        }
    }
    if(t==-1){
        int Inv=FP(len,Mod-2);
        for(int i=0;i<len;++i)A[i]=1ll*A[i]*Inv%Mod;
    }
}
inline void Mul(int *A,int *B){
    NTT(A,1);NTT(B,1);
    for(int i=0;i<len;++i)A[i]=1ll*A[i]*B[i]%Mod;
    NTT(A,-1);
}

vector<int> poly[Maxn<<2];
void solve(int k,int l,int r){
    if(l==r){
        poly[k].push_back(0);
        for(int i=1;i<=a[l];++i)poly[k].push_back(1ll*C(a[l]-1,i-1)*inv[i]%Mod);
        return;
    }
    int mid=(l+r)>>1;
    solve(k<<1,l,mid);
    solve(k<<1|1,mid+1,r);
    int l1=poly[k<<1].size()-1,l2=poly[k<<1|1].size()-1;
    for(int i=0;i<=l1;++i)A[i]=poly[k<<1][i];
    for(int i=0;i<=l2;++i)B[i]=poly[k<<1|1][i];
    len=1;bit=0;
    while(len<=l1+l2)len<<=1,bit++;
    for(int i=l1+1;i<len;++i)A[i]=0;
    for(int i=l2+1;i<len;++i)B[i]=0;
    for(int i=0;i<len;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    Mul(A,B);
    for(int i=0;i<=l1+l2;++i)poly[k].push_back(A[i]);
}

inline void rd(int &x){
    x=0;char ch=getchar();
    while(ch<'0'||ch>'9')ch=getchar();
    while(ch>='0'&&ch<='9'){
        x=x*10+ch-'0';
        ch=getchar();
    }
}

int main(){
    rd(n);m=0;
    for(register int i=1;i<=n;++i){
        rd(a[i]);
        m+=a[i];
    }
    fact[0]=1;
    for(register int i=1;i<=m;++i)fact[i]=1ll*fact[i-1]*i%Mod;
    inv[0]=inv[1]=1;
    for(register int i=2;i<=m;++i)inv[i]=1ll*(Mod-Mod/i)*inv[Mod%i]%Mod;
    for(register int i=2;i<=m;++i)inv[i]=1ll*inv[i-1]*inv[i]%Mod;
    solve(1,1,n);
    for(register int i=0;i<=m;++i)A[i]=1ll*A[i]*fact[i]%Mod*fact[m-i]%Mod;
    for(register int i=0;i<=m;++i){
        if(i&1)B[i]=Mod-inv[i];
        else B[i]=inv[i];
    }
    len=1;bit=0;
    while(len<=2*m)len<<=1,bit++;
    for(register int i=0;i<=len;++i)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
    Mul(A,B);
    for(register int i=1;i<=m;++i)A[i]=1ll*A[i]*inv[m-i]%Mod;
    rd(q);
    int x;
    while(q--){
        rd(x);
        printf("%d\n",A[m-x]);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/ezoilearner/article/details/84704313