[HAOI2018] 染色 - NTT,组合计数

对一个长度为 \(n\) 的序列进行染色,有 \(m\) 种颜色。对一种方案,如果恰好出现 \(s\) 次的颜色总数为 \(k\),则得分为 \(W_k\),求所有染色方案得分的总和。\(n\leq 10^7,m\leq 10^5,s \leq 150\)

Solution

最大有效颜色数为 \(lim=\min(m,n/s)\)

设恰好出现 \(s\) 次的颜色有至少 \(i\) 种的方案数位 \(f[i]\),则选出这 \(i\) 种颜色,并给他们分配位置,剩下的相互独立填入即可,即
\[ f[i]=C_m^i \frac{n!}{(S!)^i (n-iS)!}(m-i)^{n-iS} \]
\(ans[i]\) 表示出现 \(s\) 次的颜色恰好有 \(i\) 种的方案数,由容斥原理,
\[ ans[i]=\sum_{j=i}^{lim} (-1)^{j-i} C_j^i f[j] \]
为了方便做卷积,变形为
\[ ans[i]\cdot i!=\sum_{j=i}^{lim} \frac{(-1)^{j-i}}{(j-i)!} f[j]\cdot j! \]
不妨令
\[ A[k]=\frac{(-1)^k}{k!} \quad \quad B[k]=f[k]\cdot k! \]
那么卷积就可以描述为
\[ C[i]=\sum_{j-k=i} A[k]B[j] \]
很自然地,设 \(D[k]=A[lim-k]\),则
\[ C[i]=\sum_{j-k=i} D[lim-k]B[j]=\sum_{j+u=lim+i} D[u]B[j]=\sum_{j+k=lim+i}B[j]D[k] \]
答案就是
\[ \sum_{i=0}^{lim} ans[i]\cdot W_i \]

#include <bits/stdc++.h>
using namespace std;
#define pw(n) (1<<n)
using namespace std;
#define int long long
namespace NTT {
    const int N=4000005;
    const int mod=1004535809,g=3;
    int n,m,bit,bitnum,a[N+5],b[N+5],rev[N+5];
    void getrev(int l){
        for(int i=0;i<pw(l);i++){
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(l-1));
        }
    }
    int fastpow(int a,int b){
        int ans=1;
        for(;b;b>>=1,a=1LL*a*a%mod){
            if(b&1)ans=1LL*ans*a%mod;
        }
        return ans;
    }
    void NTT(int *s,int op){
        for(int i=0;i<bit;i++)if(i<rev[i])swap(s[i],s[rev[i]]);
        for(int i=1;i<bit;i<<=1){
            int w=fastpow(g,(mod-1)/(i<<1));
            for(int p=i<<1,j=0;j<bit;j+=p){
                int wk=1;
                for(int k=j;k<i+j;k++,wk=1LL*wk*w%mod){
                    int x=s[k],y=1LL*s[k+i]*wk%mod;
                    s[k]=(x+y)%mod;
                    s[k+i]=(x-y+mod)%mod;
                }
            }
        }
        if(op==-1){
            reverse(s+1,s+bit);
            int inv=fastpow(bit,mod-2);
            for(int i=0;i<bit;i++)a[i]=1LL*a[i]*inv%mod;
        }
    }
    void solve(vector <int> A,vector <int> B,vector <int> &C) {
        n=A.size()-1;
        m=B.size()-1;
        for(int i=0;i<=n;i++) a[i]=A[i];
        for(int i=0;i<=m;i++) b[i]=B[i];
        m+=n;
        bitnum=0;
        for(bit=1;bit<=m;bit<<=1)bitnum++;
        getrev(bitnum);
        NTT(a,1);
        NTT(b,1);
        for(int i=0;i<bit;i++)a[i]=1LL*a[i]*b[i]%mod;
        NTT(a,-1);
        C.clear();
        for(int i=0;i<=m;i++) C.push_back(a[i]);
    }
}

struct poly {
    const int mod=1004535809;
    vector <int> a;
    void cut(int n) {
        while(a.size()>n) a.pop_back();
    }
    void load(int *x,int n) {
        a.clear();
        for(int i=0;i<n;i++) a.push_back(x[i]);
    }
    poly operator *(int b) {
        poly c=*this;
        for(int i=0;i<a.size();i++) (((c.a[i]*=b)%=mod)+=mod)%=mod;
        return c;
    }
    poly operator *(const poly &b) {
        poly c;
        NTT::solve(a,b.a,c.a);
        return c;
    }
    poly operator +(poly b) {
        int len=max(a.size(),b.a.size());
        a.resize(len);
        b.a.resize(len);
        poly c;
        for(int i=0;i<len;i++) c.a.push_back((a[i]+b.a[i])%mod);
        return c;
    }
    poly operator -(poly b) {
        int len=max(a.size(),b.a.size());
        a.resize(len);
        b.a.resize(len);
        poly c;
        for(int i=0;i<len;i++) c.a.push_back(((a[i]-b.a[i])%mod+mod)%mod);
        return c;
    }
};

const int N = 10000005;
const int M = 2000005;
const int mod = 1004535809;
int n,m,s,w[M],frac[N],D[M],B[M],C[M],lim;

int qpow(int p,int q) {
    return (q&1?p:1) * (q?qpow(p*p%mod,q/2):1) %mod;
}
int inv(int p) {
    return qpow(p,mod-2);
}
void init_frac() {
    frac[0]=1;
    for(int i=1;i<N;i++) frac[i]=frac[i-1]*i%mod;
}
int c(int n,int m) {
    return frac[m]*inv(frac[n])%mod*inv(frac[m-n])%mod;
}

signed main() {
    ios::sync_with_stdio(false);
    init_frac();
    cin>>n>>m>>s;
    lim=min(m,n/s);
    for(int i=0;i<=m;i++) cin>>w[i];
    for(int i=0;i<=lim;i++) {
        D[i]=(((lim-i)&1?-1:1)*inv(frac[lim-i])+mod)%mod;
        B[i]=c(i,m)*frac[n]%mod*qpow(m-i,n-i*s)%mod*frac[i]%mod*
            inv(qpow(frac[s],i))%mod*inv(frac[n-i*s])%mod;
    }
    poly pB,pD;
    pB.load(B,lim+1);
    pD.load(D,lim+1);
    poly pC=pB*pD;
    int ans=0;
    for(int i=0;i<=lim;i++) ans+=pC.a[i+lim]*inv(frac[i])%mod*w[i]%mod, ans+=mod, ans%=mod;
    cout<<ans;
}

猜你喜欢

转载自www.cnblogs.com/mollnn/p/12432940.html