习题:唱、跳、rap和篮球(生成函数)

题目

传送门

思路

很容易可以发现正向来求是瞬间爆炸的

正难♂则反

于是我们思考总的方案数来求不符合的方案数来进行容斥

我们假设至少不相交的i段是唱跳rap篮球

我们考虑枚举这i个区间的起点位置,那么总方案数是\(C_{n-3i}^i\)种方案

为什么是3*i呢?

题意中有明确只要连续的4个人,并且只要这4个人当中集齐了唱跳rap篮球就行,对顺序没有要求

也就是我们对于一个区间,我们的后三个位置不能再次选择,因为选择之后就会违背我们对枚举的i段的定义

这里说的不能再次选择,指的是某一段唱跳rap篮球的开头不在这三个位置

我们排除这些位置后总的能选择的位置就是\(n-4i\),这些位置是能随便乱放的

尝试写出生成函数,

对于第k种学生,如果选了\(t_k\)

\(G_k(x)=\sum_{i=0}^{t_k}\frac{1}{i!}x^i\)

注意此时我们只是选出学生,

只有对于学生构成不同的才算作不同的方案,

我们首先将其卷起来,

\(F(x)=\prod_{i=1}^{4}G_i(x)\)

又因为我们只是选出来,

我们还需要将其进行排序,答案即为

\((n-4*i)!F(x)\)

对于后面的累乘,我们需要的只是\(x^{n-4i}\)项的系数

所以答案为\((n-4i)f_{n-4i}\)

至于容斥,就按一般的容斥写法写就行了

代码

#include<iostream>
#include<algorithm>
#include<vector>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
namespace polynomial
{ 
    const int N=1e3+10;
    const int M=N<<3;
    const int mod=998244353;
    const int G=3;
    const int SIZE=sizeof(int);
    #define poly vector<int>
    #define int long long
    int w[M],rev[M];
    poly resize(poly f,int n)
    {
        f.resize(n);
        return f;
    }
    int sub(int a,int b)
    {
        a-=b;
        if(a<0)
            return a+mod;
        else
            return a;
    }
    int add(int a,int b)
    {
        a+=b;
        if(a>=mod)
            return a-mod;
        else
            return a;
    }
    int inv(int x)
    {
        if(x<2)
            return 1;
        else
            return (1ll*mod-mod/x)*inv(mod%x)%mod;
    }
    int qkpow(int a,int b)
    {
        if(b==0)
            return 1;
        if(b==1)
            return a;
        int t=qkpow(a,b/2);
        t=(1ll*t*t)%mod;
        if(b%2==1)
            t=(1ll*t*a)%mod;
        return t;
    }
    void ntt(int *a,int lim)
    {
        for(int i=0;i<lim;i++)
            if(i<rev[i])
                swap(a[i],a[rev[i]]);
        for(int len=1;len<lim;len<<=1)
        {
            for(int i=0;i<lim;i+=(len<<1))
            {
                for(int j=0;j<len;j++)
                {
                    int x=a[i+j];
                    int y=1ll*w[j+len]*a[i+j+len]%mod;
                    a[i+j]=add(x,y);
                    a[i+j+len]=sub(x,y);
                }
            }
        }
    }
    void ntt_init()
    {
        int wn;
        for(int len=1;(len<<1)<M;len<<=1)
        {
            wn=qkpow(G,(mod-1)/(len<<1));
            w[len]=1;
            for(int i=1;i<len;i++)
                w[i+len]=1ll*w[i+len-1]*wn%mod;
        }
    }
    int init(int len)
    {
        int lim=1;
        int k=0;
        while(lim<len)
        {
            lim<<=1;
            k++;
        }
        for(int i=0;i<lim;i++)
            rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
        return lim;
    }
    poly operator * (const poly &f,const poly &g)
    {
        int a[M]={};
        int b[M]={};
        int lim=init(f.size()+g.size()-1);
        int inv_len=inv(lim);
        poly ret;  
        for(int i=0;i<f.size();i++)
            a[i]=f[i];
        for(int i=0;i<g.size();i++)
            b[i]=g[i];
        ntt(a,lim);
        ntt(b,lim);
        for(int i=0;i<lim;i++)
            a[i]=1ll*a[i]*b[i]%mod;
        reverse(a+1,a+lim);
        ntt(a,lim);
        for(int i=0;i<f.size()+g.size()-1;i++)
            ret.push_back(1ll*a[i]*inv_len%mod);
        return ret;
    }
    #undef int
    #undef poly
};
using namespace polynomial;
int n;
int num[5];
long long fac[1005];
long long ans;
long long c[1005][1005];
long long solve_c(long long n,long long m)
{
    if(c[n][m])
        return c[n][m];
    long long ret=1;
    if(m>n/2)
        m=n-m;
    for(long long i=n;i>=n-m+1;i--)
        ret=(ret*i%mod)*qkpow(n-i+1,mod-2)%mod;
    return c[n][m]=ret;
}
void prepare()
{
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=fac[i-1]*i%mod;
}
long long solve(int k)
{
    vector<long long> a[5];
    for(int i=1;i<=4;i++)
        for(int j=0;j<=num[i]-k;j++)
            a[i].push_back(qkpow(fac[j],mod-2));
    vector<long long> g;
    g=a[1]*a[2];
    g=g*a[3];
    g=g*a[4];
    return g[n-4*k]*fac[n-4*k]%mod;
}
int main()
{
    ios::sync_with_stdio(false);
    ntt_init();
    cin>>n;
    prepare();
    for(int i=1;i<=4;i++)
        cin>>num[i];
    sort(num+1,num+5);
    for(int i=0;(i<<2)<=min(n,num[1]<<2);i++)
    {
        long long t=solve(i)*solve_c(n-3*i,i)%mod;
        //cout<<solve(i)<<' '<<solve_c(n-3*i,i)<<'\n';
        if(i%2==1)
            ans-=t;
        else
            ans+=t;
        ans=ans%mod;
    }
    cout<<(ans+mod)%mod;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/loney-s/p/12109326.html
RAP