SDOI2015 序列统计

Description

小C有一个集合\(S\),里面的元素都是小于\(M\)的非负整数。他用程序编写了一个数列生成器,可以生成一个长度为\(N\)的数列,

数列中的每个数都属于集合\(S\)。小C用这个生成器生成了许多这样的数列。但是小C有一个问题需要你的帮助:

给定整数\(x\),求所有可以生成出的,且满足数列中所有数的乘积\(\mod M\)的值等于\(x\)的不同的数列的有多少个。

\(C\)认为,两个数列\(\{Ai\}\)\(\{Bi\}\)不同,当且仅当至少存在一个整数\(i\),满足\(Ai\neq Bi\)

另外,小C认为这个问题的答案可能很大,因此他只需要你帮助他求出答案\(\mod 1004535809\)的值就可以了。

Input

一行,四个整数,\(N、M、x、|S|\),其中\(|S|\)为集合\(S\)中元素个数。

第二行,\(|S|\)个整数,表示集合\(S\)中的所有元素。

\(1 \leq N \leq 10^9,3 \leq M \leq 8000\),M为质数

\(0 \leq x \leq M-1\),输入数据保证集合S中元素不重复\(x \in [1,m-1]\)

集合中的数$ \in [0,m-1]$

Output

一行,一个整数,表示你求出的种类数\(\mod 1004535809\)的值。

Solution

看到这题。。首先很容易列出一个DP转移方程

\[ F[i][j]=\sum_{a+b \equiv j (\mod m) } {F[i-1][a]*F[i-1][b]} \]

我们发现它非常不优美,复杂度高达$ O (n * m^2) $

我们发现这个式子可以倍增。。于是很轻松的干掉一个n,它的复杂度变成了$ O(\log n * m^2) $

这貌似还是有点多。。考虑如何干掉一个 $ m $

咦。。这个模数貌似有点熟悉。。考虑NTT

不过这是乘法。。我们做不了NTT 。。。

考虑原根

\(p\)\(m\)的原根。。那么\(p\)的幂次可以表示出\([1,m)\)的所有数字————原根定义

于是DP方程变成了这样

\[ F[i][j]=\sum_{g^a + g^b \equiv g^j (\mod m)}{F[i-1][a]*F[i-1][b]} \]

注意。。此时\(F[i][j]\)表示选到第\(i\)个数,大小为\(p^j\)次的方案数

再一变

\[ F[i][j]=\sum_{a+b \equiv j (\mod m-1)}{F[i-1][a]*F[i-1][b]} \]

我们发现这玩意长得像个卷积。。可以用NTT了

于是复杂度变成了 $ O(m \log n log m)$

Code

#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
int n,m,x,lens,g[2000000],a[2000010],b[2000010];
int f[2000010],ans[2000010];
int fpow(int x,int k,int Mod)
{
    int ans=1;
    while (k)
    {
        if (k&1) ans=1LL*ans*x%Mod;
        x=1LL*x*x%Mod;
        k>>=1;
    }
    return ans;
}
namespace GetRoot //求原根
{
    int prime[1000000],cnt;
    bool check(int x,int p)
    {
        for (int i=1;i<=cnt;i++)
            if (fpow(x,(p-1)/(prime[i]),p)==1) return 0;
        return 1;
    }
    int find(int p)
    {
        int x=p-1;
        for (int i=2;i*i<=x;i++)
        {
            if (x%i==0)
            {
                prime[++cnt]=i;
                while (x%i==0) x/=i; 
            }
        }
        if (x!=1) prime[++cnt]=x;
        for (int i=2;;i++)
            if (check(i,p)) return i;
    }
}
namespace NTT
{
    const int Mod=1004535809,p=3;
    int n=1;
    void NTT(int *a,int inv)
    {
        int lim=0;
        while ((1<<lim)<n) lim++;
        for (int i=0;i<n;i++)
        {
            int t=0;
            for (int j=0;j<lim;j++)
                if ((i>>j) & 1) t|=1<<(lim-j-1);
            if (i<t) swap(a[i],a[t]);
        }
        for (int l=2;l<=n;l*=2)
        {
            int m=l/2,p0=fpow(inv?fpow(p,Mod-2,Mod):p,(Mod-1)/l,Mod);
            for (int *buf=a;buf!=a+n;buf+=l)
            {
                int pn=1;
                for (int i=0;i<m;i++)
                {
                    int t=1LL*pn*buf[i+m]%Mod;
                    buf[i+m]=(buf[i]-t+Mod)%Mod;
                    buf[i]=(buf[i]+t)%Mod;
                    pn=1LL*pn*p0%Mod;
                }
            }
        }
}
    void Union(int *a,int *c,int len)
    {
        while (n<2*len) n<<=1;
        for (int i=0;i<n;i++) b[i]=0;
        for (int i=0;i<len;i++) b[i]=c[i];
        NTT(a,0);NTT(b,0);
        for (int i=0;i<n;i++) a[i]=1LL*a[i]*b[i]%Mod;
        NTT(a,1);
        int invn=fpow(n,Mod-2,Mod);
        for (int i=0;i<n;i++) a[i]=1LL*a[i]*invn%Mod;
        for (int i=len-1;i<n;i++) a[i%(len-1)]=(a[i%(len-1)]+a[i])%Mod,a[i]=0;
    }
}
void init()
{
    int t=GetRoot::find(m);
    for (int i=0,k=1;i<m-1;i++,k=1LL*k*t%m) g[k]=i;
    x=g[x];
    for (int i=1;i<=lens;i++)
        if (a[i]) f[g[a[i]]]++;  //若a[i]=0.就直接舍弃。
}
void solve() //倍增优化
{
    int k=n;
    ans[0]=1;
    while (k)
    {
        if (k&1) NTT::Union(ans,f,m);
        NTT::Union(f,f,m);
        k>>=1;
    }
    printf("%d\n",ans[x]);
}
int main()
{
    scanf("%d%d%d%d",&n,&m,&x,&lens);
    for (int i=1;i<=lens;i++) scanf("%d",&a[i]);
    init();
    solve();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Code-Geass/p/10261389.html