#loj 3058 [HNOI2019] 白兔之舞

单位根反演思博题

模数是乱给的记得整个任意模数ntt

k为p-1的约数意味着一定存在k次单位根,设g是p的原根则\(w_{k}^{1}=g^{\frac{k-1}{p}}\)

既然k次单位根存在自然考虑单位根反演了

\(f(i)\)表示跳了i步并且停在了第二维为y的顶点的方案数

\(st\)表示初始向量而\(A\)表示转移矩阵,那么有

\[f(i)={n \choose i}(st×A^i)(y)\]

那么可以构造数列\(f\)的一般生成函数\(F(z)=\sum_{k}f(k)z^k\)

现在先让我们考虑如何求出\(F(z)\)中恰好是k的倍数的项之和,即求出

\[\sum_{i}[i|k]{n \choose i}(st×A^i)(y)\]

\[\sum_{i}\sum_{j}\omega_{k}^{ij}{n \choose i}(st×A^i)(y)\]

\[\sum_{i}\sum_{j}\omega_{k}^{ij} {n \choose i}(st×A^i)(y)\]

\[\sum_{j}\sum_{i} {n \choose i}(\omega_{k}^{ij}A^i)(y)\]

\[\sum_{j}(st×(\sum_{i} {n \choose i}(\omega_{k}^{j}A)^i))(y)\]

\[\sum_{j}(st×(1+\omega_{k}^{j}A)^n)(y)\]

看起来就很好算了吧

啊,如果现在要求模k是某个特定的值怎么办呢?

很简单啊,把f数列整体平移若干个单位就可以了

为了方便起见设\(g(j)=(st×(1+\omega_{k}^{j}A)^n)(y)\)

如果我们现在处理的生成函数是\(F(z)z^t\)的话,我们的式子应该长这样

\[\sum_{j}g(j)\omega_{k}^{tj}\]

那么如何对于每一个t处理出上面的东西呢?

教练我会多点求值!

大常数\(O(nlog^2n)\)肯定是凉的

这其实是任意长度fft,有一个\(O(nlogn)\)的优秀做法

具体来讲借助了这个恒等式

\[tj={t+j \choose 2}-{t \choose 2}-{j \choose 2}\]

那么现在我们要求的东西就变成了

\[\omega_{k}^{-{t \choose 2}}\sum_{j}\frac{g(j)}{\omega_{k}^{{j \choose 2}}}\omega_{k}^{{t+j \choose 2}}\]

\(p(n)=\frac{g(n)}{q(n)},q(n)=\omega_{k}^{{n \choose 2}}\)

则我们算的就是

\[\omega_{k}^{-{t \choose 2}}\sum_{i-j=t}g(j)q(i)\]

显然是个差卷积,随便fft几下就行了

#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;const int N=262144+10;const int D=18;
typedef long long ll;typedef double db;const db pi=acos(-1.0);
const int P=32768;const int SF=15;const int msk=32767;ll PP;ll mod;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
namespace poly
{
    struct cmp
    {
        db r;db v;
        friend cmp operator +(cmp a,cmp b){return (cmp){a.r+b.r,a.v+b.v};}
        friend cmp operator -(cmp a,cmp b){return (cmp){a.r-b.r,a.v-b.v};}
        friend cmp operator *(cmp a,cmp b){return (cmp){a.r*b.r-a.v*b.v,a.r*b.v+a.v*b.r};}
        void operator /=(const db& b){r/=b;v/=b;}
    }tr[N],tr1[N],tr2[N],tr3[N],tr4[N],rt[2][20][N];
    int rv[20][N];ll m13[N],m24[N],m14[N],m23[N];
    inline void pre()
    {
        for(int d=1;d<=D;d++)
            for(int i=0;i<(1<<d);i++)rv[d][i]=(rv[d][i>>1]>>1)|((i&1)<<(d-1));
        for(int d=1,t=1;d<=D;d++,t<<=1)
            for(int i=0;i<(1<<d);i++)rt[0][d][i]=(cmp){cos(i*pi/t),sin(i*pi/t)};
        for(int d=1,t=1;d<=D;d++,t<<=1)
            for(int i=0;i<(1<<d);i++)rt[1][d][i]=(cmp){cos(i*pi/t),-sin(i*pi/t)};
        PP=(ll)P*P%mod;
    }
    inline void fft(cmp* a,int len,int d,int o)
    {
        for(int i=0;i<len;i++)if(i<rv[d][i])swap(a[i],a[rv[d][i]]);
        int i;cmp * w;
        for(int k=1,j=1;k<len;k<<=1,j++)
            for(int s=0;s<len;s+=(k<<1))
                for(i=s,w=rt[o][j];i<s+k;++i,++w)
                    {cmp a1=a[i+k]*(*w);a[i+k]=a[i]-a1;a[i]=a[i]+a1;}
        if(o)for(int i=0;i<len;i++)a[i]/=len;
    }
    inline void dbdft(ll* a,int len,int d,cmp* op1,cmp* op2)
    {
        for(int i=0;i<len;i++)tr[i]=(cmp){(db)(a[i]>>SF),(db)(a[i]&msk)};
        fft(tr,len,d,0);tr[len]=tr[0];
        for(cmp* p1=tr,*p2=tr+len,*p3=op1;p1!=tr+len;++p1,--p2,++p3)
            (*p3)=(cmp){p1->r+p2->r,p1->v-p2->v}*(cmp){0.5,0};
        for(cmp *p1=tr,*p2=tr+len,*p3=op2;p1!=tr+len;++p1,--p2,++p3)
            (*p3)=(cmp){p1->r-p2->r,p1->v+p2->v}*(cmp){0,-0.5};
    }
    inline void dbidft(cmp* a,int len,int d,ll* op1,ll* op2)
    {
        fft(a,len,d,1);
        for(int i=0;i<len;i++)op1[i]=(ll)(a[i].r+0.5)%mod;
        for(int i=0;i<len;i++)op2[i]=(ll)(a[i].v+0.5)%mod;
    }
    cmp tst[N];
    inline void mul(ll* a,ll* b,ll* c,int len,int d)
    {
        dbdft(a,len,d,tr1,tr2);dbdft(b,len,d,tr3,tr4);
        for(int i=0;i<len;i++)
            tr[i]=tr1[i]*tr3[i]+(cmp){0,1}*tr2[i]*tr4[i];
        dbidft(tr,len,d,m13,m24);
        for(int i=0;i<len;i++)
            tr[i]=tr1[i]*tr4[i]+(cmp){0,1}*tr2[i]*tr3[i];
        dbidft(tr,len,d,m14,m23);
        for(int i=0;i<len;i++)
            c[i]=(m13[i]*PP+(m14[i]+m23[i])*P+m24[i])%mod;
    }
}
namespace calcg
{
    int zhi[N];int ct;int nu[N];int divs[N];int hd;
    inline bool ck(int g)
    {
        for(int i=1;i<=hd;i++)
            if(divs[i]!=mod-1&&po(g,divs[i])==1)return false;
        return true;
    }
    inline void dfs(int cur,int nw)
    {
        if(cur==ct+1){divs[++hd]=nw;return;}
        for(int i=0;i<=nu[cur];i++,nw*=zhi[cur])
            dfs(cur+1,nw);
    }
    inline int solve()
    {
        ll phi=mod-1;
        for(ll i=2;i*i<=phi;i++)
            if(phi%i==0)
            {
                zhi[++ct]=i;
                while(phi%i==0)nu[ct]++,phi/=i;
            }
        if(phi!=1)zhi[++ct]=phi,nu[ct]=1;
        dfs(1,1);
        for(int g=2;g<=mod-1;g++)
            if(ck(g))return g;
        return -1;
    }
}
int S;
struct mar
{
    ll mp[4][4];
    inline ll * operator [](const int& x){return mp[x];}
    mar()
    {
        for(int i=0;i<4;i++)
            for(int j=0;j<4;j++)mp[i][j]=0;
    }
    friend mar operator *(mar a,mar b)
    {
        mar c;
        for(int i=1;i<=S;i++)
            for(int k=1;k<=S;k++)
                for(int j=1;j<=S;j++)
                    (c[i][j]+=a[i][k]*b[k][j])%=mod;
        return c;
    }
}st,trs,ori;ll gen;ll omg;
ll f[N];ll sw[N];int n;int k;int l;int x;int y;
ll res[N];ll ans[N];
inline ll ctwo(ll n){return ((ll)n*(n-1)/2)%(mod-1);}
int main()
{
    scanf("%d%d%d%d%d%lld",&n,&k,&l,&x,&y,&mod);S=n;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=n;j++)
            scanf("%lld",&ori[i][j]);
    poly::pre();gen=calcg::solve();
    omg=po(gen,(mod-1)/k);
    for(int i=0;i<=2*k;i++)sw[i]=po(omg,ctwo(i));
    for(int t=0;t<k;t++)
    {
        st=mar();st[1][x]=1;ll wkt=po(omg,t);
        for(int i=1;i<=n;i++)
            for(int j=1;j<=n;j++)
                trs[i][j]=ori[i][j]*wkt%mod;
        for(int i=1;i<=n;i++)
                (trs[i][i]+=1)%=mod;
        for(int p=l;p;p>>=1,trs=trs*trs)if(p&1)st=st*trs;
        f[t]=st[1][y];
    }
    for(int t=0;t<k;t++)
        (f[t]*=po(sw[t],mod-2))%=mod;
    for(int t=0;t<k;t++)
        if(t<(k-t))swap(f[t],f[k-t]);
    int len=1;int d=0;
    while(len<k+k)len<<=1,d++;
    poly::mul(f,sw,res,len,d);
    for(int i=0;i<k;i++)
        (ans[i]=res[i+k]*po(k*sw[i]%mod,mod-2))%=mod;
    ans[k]=ans[0];
    for(int i=k;i>=1;i--)
        printf("%lld\n",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/sweetphoenix/p/10833517.html