【XSY3147】子集计数 DFT 组合数学

题目大意

  给定一个集合 \(\{1,2,\ldots,n\}\),要求你从中选出 \(m\) 个数,且这 \(m\) 个数的和是 \(k\)。问方案数 \(\bmod 998244353\)

  \(0\leq k<n<998244353,m\leq n\)

题解

  先不考虑选的数的个数的限制。

  显然答案的 OGF 为
\[ F(x)=\prod_{i=0}^{n-1}(1+x^i) \]

  考虑答案的式子
\[ \begin{align} ans&=\frac{1}{n}\sum_{i=0}^{n-1}F(\omega_n^i)\omega_n^{-ik}\\ &=\frac{1}{n}\sum_{i=0}^{n-1}\prod_{j=0}^{n-1}(1+\omega_n^{ij})\omega_n^{-ik} \end{align} \]
  记 \(d=\gcd(n,i)\),那么
\[ \begin{align} ans&=\frac{1}{n}\sum_{i=0}^{n-1}\prod_{j=0}^{n-1}(1+\omega_n^{ij})\omega_n^{-ik}\\ &=\frac{1}{n}\sum_{i=0}^{n-1}\prod_{j=0}^{\frac{n}{d}-1}{(1+\omega_\frac{n}{d}^{ij})}^d\omega_n^{-ik}\\ \end{align} \]
  容易发现,\(\prod_{i=0}^{n-1}(1+\omega_n^i)=1-{(-1)}^n\)
\[ \begin{align} ans&=\frac{1}{n}\sum_{i=0}^{n-1}\prod_{j=0}^{\frac{n}{d}-1}{(1+\omega_\frac{n}{d}^{ij})}^d\omega_n^{-ik}\\ &=\frac{1}{n}\sum_{i=0}^{n-1}{(\prod_{j=0}^{\frac{n}{d}-1}{(1+\omega_\frac{n}{d}^{ij})})}^d\omega_n^{-ik}\\ &=\frac{1}{n}\sum_{i=0}^{n-1}{(1-{(-1)}^\frac{n}{d})}^d\omega_n^{-ik}\\ &=\frac{1}{n}\sum_{d\mid n}{(1-{(-1)}^\frac{n}{d})}^d(\sum_{\gcd(i,n)=d}\omega_n^{-ik})\\ &=\frac{1}{n}\sum_{d\mid n}{(1-{(-1)}^\frac{n}{d})}^d(\sum_{\gcd(i,\frac{n}{d})=1}\omega_\frac{n}{d}^{-ik})\\ &=\frac{1}{n}\sum_{d\mid n}{(1-{(-1)}^\frac{n}{d})}^d(\sum_{j\mid \frac{n}{d}}\mu(j)\sum_{i=0}^{\frac{n}{dj}-1}\omega_\frac{n}{dj}^{-ik})\\ &=\frac{1}{n}\sum_{d\mid n}{(1-{(-1)}^\frac{n}{d})}^d(\sum_{j\mid \frac{n}{d}}\mu(j)[\frac{n}{dj}\mid k]\frac{n}{dj})\\ \end{align} \]
  现在要加上 \(m\) 这个限制,只需要多加一个元 \(y\),把初始的 OGF 改为
\[ \prod_{i=0}^{n-1}(1+x^iy) \]
  推导过程中把 \(y\) 当做一个常量,就像 \(x\) 一样。最后的式子为
\[ [y^m]\frac{1}{n}\sum_{d\mid n}{(1-{(-y)}^\frac{n}{d})}^d(\sum_{j\mid \frac{n}{d}}\mu(j)[\frac{n}{dj}\mid k]\frac{n}{dj})\\ \]
  还有一个问题就是要在二项式展开的时候计算组合数。直接分段打表就好了。

  时间复杂度:\(O(\sigma_0(n)^2)\)

代码

  阶乘的表删掉了。

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<assert.h>
using namespace std;
typedef long long ll;
const ll p=998244353;
ll fp(ll a,ll b)
{
    ll s=1;
    for(;b;b>>=1,a=a*a%p)
        if(b&1)
            s=s*a%p;
    return s;
}
int n,m,k;
ll getmiu(int x)
{
    ll s=1;
    for(int i=2;i*i<=x;i++)
        if(x%i==0)
        {
            int tmp=0;
            while(x%i==0)
            {
                tmp++;
                x/=i;
            }
            if(tmp>=2)
                return 0;
            s=-s;
        }
    if(x!=1)
        s=-s;
    return s;
}
int c[1];
const int N=10000000;
const int D=200000;
int fac[N+10];
ll factorial(int n)
{
    if(n<=N)
        return fac[n];
    ll s=c[n/D];
    for(int i=n/D*D+1;i<=n;i++)
        s=s*i%p;
    return s;
}
ll binom(int x,int y)
{
    return x>=y&&y>=0?factorial(x)*fp(factorial(y)*factorial(x-y)%p,p-2)%p:0;
}
void init()
{
    fac[0]=1;
    for(int i=1;i<=N;i++)
        fac[i]=(ll)fac[i-1]*i%p;
}
int gcd(int a,int b)
{
    return b?gcd(b,a%b):a;
}
int exgcd(int a,int b,int &x,int &y)
{
    if(!b)
    {
        x=1;
        y=0;
        return a;
    }
    int d=exgcd(b,a%b,y,x);
    y-=a/b*x;
    return d;
}
int get(int a)
{
    int x,y;
    int d=exgcd(a,p-1,x,y);
    if(x<0)
    {
        printf("%d %d\n%lld\n",x,y,(ll)x*a+y*(p-1));
        x+=(p-1)/d;
        y-=a/d;
        printf("%d %d\n%lld\n\n",x,y,(ll)x*a+y*(p-1));
    }
    return x;
}
ll mu[10000];
int d[10000];
int t=0;
int query(int x)
{
    return mu[lower_bound(d+1,d+t+1,x)-d];
}
int main()
{
#ifndef ONLINE_JUDGE
    freopen("c.in","r",stdin);
    freopen("c.out","w",stdout);
#endif
    init();
    scanf("%d%d%d",&n,&m,&k);
    for(int i=1;i*i<=n;i++)
        if(n%i==0)
        {
            d[++t]=i;
            if(i*i!=n)
                d[++t]=n/i;
        }
    sort(d+1,d+t+1);
    for(int i=1;i<=t;i++)
        mu[i]=getmiu(d[i]);
    ll ans=0;
    for(int i=1;i<=t;i++)
    {
        if(m%(n/d[i]))
            continue;
        int w=m/(n/d[i]);
        int v=(w&1?-((n/d[i])&1?-1:1):1);
        ll s1=binom(d[i],w);
        ll s2=0;
        for(int j=1;j<=t;j++)
            if((n/d[i])%d[j]==0)
                if(k%(n/d[i]/d[j])==0)
                    s2=(s2+mu[j]*n/d[i]/d[j])%p;
        ans=(ans+s1*s2*v)%p;
    }
    ans=ans*fp(n,p-2)%p;
    ans=(ans+p)%p;
    printf("%lld\n",ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/ywwyww/p/9279403.html
dft
今日推荐