洛谷 P3321 [SDOI2015]序列统计 dp+快速幂+任意模数fft

题目描述

小C有一个集合S,里面的元素都是小于M的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为N的数列,数列中的每个数都属于集合S。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:给定整数x,求所有可以生成出的,且满足数列中所有数的乘积mod M的值等于x的不同的数列的有多少个。小C认为,两个数列{Ai}和{Bi}不同,当且仅当至少存在一个整数i,满足Ai≠Bi。另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案mod 1004535809的值就可以了。

输入输出格式

输入格式:
一行,四个整数,N、M、x、|S|,其中|S|为集合S中元素个数。第二行,|S|个整数,表示集合S中的所有元素。

输出格式:
一行,一个整数,表示你求出的种类数mod 1004535809的值。

输入输出样例

输入样例#1:
4 3 1 2
1 2
输出样例#1:
8
说明

【样例说明】

可以生成的满足要求的不同的数列有(1,1,1,1)、(1,1,2,2)、(1,2,1,2)、(1,2,2,1)、(2,1,1,2)、(2,1,2,1)、(2,2,1,1)、(2,2,2,2)。

【数据规模和约定】

对于10%的数据,1<=N<=1000;

对于30%的数据,3<=M<=100;

对于60%的数据,3<=M<=800;

对于全部的数据,1<=N<=109,3<=M<=8000,M为质数,1<=x<=M-1,输入数据保证集合S中元素不重复

分析:
根据题意,容易列出dp式。

f [ x + y ] [ i ] = k l = i ( m o d   m ) f [ x ] [ k ] f [ y ] [ l ]

由于题目中的 x 不为 0 ,所以可以把 S 中的 0 去掉。然后我们发现,如果两边取 l o g ,那么就是一个卷积了。因为 m 是素数,所以先跑出原根,把每个数用离散对数表示,把 f [ i ] [ j ] 设为离散对数为 j 的方案数,式子就变成,
f [ x + y ] [ i ] = i = j + k f [ x ] [ j ] f [ y ] [ k ]

显然就是一个卷积形式,然后就可以用快速幂+ f f t 求解。当然, f [ i ] [ j ] 的项要加到 f [ i ] [ j   m o d   ( m 1 ) ]

代码:

#include <iostream>
#include <cstdio>
#include <cmath>
#define LL long long

const int maxn=33007;
const int p=1004535809;
const double pi=acos(-1);

using namespace std;

struct rec{
    double x,y;
};

rec operator +(rec a,rec b)
{
    return (rec){a.x+b.x,a.y+b.y};
}

rec operator -(rec a,rec b)
{
    return (rec){a.x-b.x,a.y-b.y};
}

rec operator *(rec a,rec b)
{
    return (rec){a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};
}

rec operator !(rec a)
{
    return (rec){a.x,-a.y};
}

LL n,m,s,t,g,cnt,len,x;
rec a[maxn],b[maxn],w[maxn],dfta[maxn],dftb[maxn],dftc[maxn],dftd[maxn];
LL f[maxn],f1[maxn],r[maxn];
LL prime[maxn],lg[maxn];

void divide(LL x)
{
    for (LL i=2;i<=trunc(sqrt(x));i++)
    {
        if (x%i==0)
        {
            prime[++cnt]=i;
            while (x%i==0) x/=i;
        }
    }
    if (x>1) prime[++cnt]=x;
}

LL power(LL x,LL y)
{
    if (y==1) return x;
    LL c=power(x,y/2);
    c=(c*c)%m;
    if (y%2) c=(c*x)%m;
    return c;
}

void getroot(LL x)
{
    bool flag;
    for (LL i=2;i<x;i++)
    {
        flag=0;
        for (LL j=1;j<=cnt;j++)
        {
            if (power(i,(x-1)/prime[j])==1)
            {
                flag=1;
                break;
            }
        }
        if (!flag)
        {
            g=i;
            return;
        }
    }
}

void fft(rec *a,LL f)
{
    for (LL i=0;i<len;i++)
    {
        if (i<r[i]) swap(a[i],a[r[i]]);
    }
    w[0]=(rec){1,0};
    for (LL i=2;i<=len;i*=2)
    {
        rec wn=(rec){cos(2*pi/i),f*sin(2*pi/i)};
        for (LL j=i/2;j>=0;j-=2) w[j]=w[j/2];
        for (LL j=1;j<i/2;j+=2) w[j]=w[j-1]*wn;
        for (LL j=0;j<len;j+=i)
        {
            for (LL k=0;k<i/2;k++)
            {
                rec u=a[j+k],v=a[j+k+i/2]*w[k];
                a[j+k]=u+v;
                a[j+k+i/2]=u-v;
            }
        }
    }
}

void init(LL len)
{
    LL k=trunc(log(len+0.5)/log(2));
    for (LL i=0;i<len;i++)
    {
        r[i]=(r[i>>1]>>1)|((i&1)<<(k-1));
    }
}

void FFT(LL *x,LL *y,LL *z,LL n,LL m)
{   
    len=1;
    while (len<(n+m-1)) len*=2;
    init(len);
    for (LL i=0;i<len;i++)
    {
        LL A,B;
        if (i<n) A=x[i]%p; else A=0;
        if (i<m) B=y[i]%p; else B=0;
        a[i]=(rec){A>>15,A&32767};
        b[i]=(rec){B>>15,B&32767};
    }
    fft(a,1); fft(b,1);
    for (LL i=0;i<len;i++)
    {
        LL j=(len-i)&(len-1);
        rec da,db,dc,dd;
        da=(a[i]+(!a[j]))*(rec){0.5,0};
        db=(a[i]-(!a[j]))*(rec){0,-0.5};
        dc=(b[i]+(!b[j]))*(rec){0.5,0};
        dd=(b[i]-(!b[j]))*(rec){0,-0.5};
        dfta[i]=da*dc;
        dftb[i]=da*dd;
        dftc[i]=db*dc;
        dftd[i]=db*dd;
    }
    for (LL i=0;i<len;i++)
    {
        a[i]=dfta[i]+dftb[i]*(rec){0,1};
        b[i]=dftc[i]+dftd[i]*(rec){0,1};
    }
    fft(a,-1); fft(b,-1);
    for (LL i=0;i<len;i++)
    {
        LL da,db,dc,dd;
        da=(LL)(a[i].x/len+0.5)%p;
        db=(LL)(a[i].y/len+0.5)%p;
        dc=(LL)(b[i].x/len+0.5)%p;
        dd=(LL)(b[i].y/len+0.5)%p;
        z[i]=((da<<30)%p+((db+dc)<<15)%p+dd)%p;
    }
}

void solve(LL n,LL m)
{
    if (n==1) return;
    solve(n/2,m);
    FFT(f,f,f,m,m);
    for (LL i=m;i<=m+m;i++) f[i%m]=(f[i%m]+f[i])%p;
    if (n%2)
    {
        FFT(f,f1,f,m,m);
        for (LL i=m;i<=m+m;i++) f[i%m]=(f[i%m]+f[i])%p;
    }
}

int main()
{
    scanf("%lld%lld%lld%lld",&n,&m,&t,&s);
    divide(m-1);
    getroot(m);
    LL k=1;
    lg[1]=0;
    for (LL i=1;i<=m-2;i++)
    {
        k=(k*g)%m;
        lg[k]=i;
    }
    for (LL i=1;i<=s;i++)
    {
        scanf("%lld",&x);
        x%=m;
        if (x!=0)
        {
            f[lg[x]]++;
            f1[lg[x]]++;
        }
    }   
    solve(n,m-1);
    printf("%lld",f[lg[t]]);
}

猜你喜欢

转载自blog.csdn.net/liangzihao1/article/details/81570946