【[NOI2013]矩阵游戏】

我们看到了及其可怕的数据范围

这个样子都没有办法直接读入的数据范围应该怎么算

我们观察一下递推式\(f[i][j]=a*f[i][j]+b(j!=1)\)

\(f[i][1]=c*f[i-1][m]+d\)

转移非常简单,于是可以考虑一下矩阵乘法

如果我们将这个矩阵破坏成一个链,那么就会有这种形式的递推

连续推\(m\)次第一个柿子,之后再推一次第二个柿子,之后反复

重复上面的过程\(n\)次就好了

于是我们可以将连续转移\(m\)次一式的到的矩阵和第二个式子的转移矩阵乘起来,之后将这个矩阵再转移\(n\)次就是答案了

由于\(n,m\)是在太大了,我们又发现模数是质数,于是我们可以利用费马小定理来降幂

则有

\[A^{mod-1}\%mod=1\]

所以

\[A^m\equiv A^{m\%(mod-1)}(\%\ mod )\]

至于为什么要特判\(a=1\)

我怎么知道啊

#include<iostream>
#include<cstring>
#include<cstdio>
#define re register
#define maxn 5
#define LL long long
const LL mod=1000000007;
char N_[1000005],M_[1000005];
LL n,m,a,b,c,d;
inline LL read()
{
    char c=getchar();
    LL x=0;
    while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9')
      x=x*10%(mod-1)+c-48,c=getchar();
    return x;
}
struct Mat
{
    LL a[4][4],ans[4][4];
    inline void did_ans()
    {
        LL mid[4][4];
        for(re int i=1;i<=2;i++)
            for(re int j=1;j<=2;j++)
                mid[i][j]=ans[i][j],ans[i][j]=0;
        for(re int i=1;i<=2;i++)
            for(re int j=1;j<=2;j++)
                for(re int p=1;p<=2;p++)
                    ans[i][j]=(ans[i][j]+(a[i][p]*mid[p][j])%mod)%mod;
    }
    inline void did_a()
    {
        LL mid[4][4];
        for(re int i=1;i<=2;i++)
            for(re int j=1;j<=2;j++)
                mid[i][j]=a[i][j],a[i][j]=0;
        for(re int i=1;i<=2;i++)
            for(re int j=1;j<=2;j++)
                for(re int p=1;p<=2;p++)
                    a[i][j]=(a[i][j]+(mid[i][p]*mid[p][j])%mod)%mod;
    }
}R,L,K;
inline void mul(Mat &A,Mat &B)
{
    LL mid[4][4];
    for(re int i=1;i<=2;i++)
        for(re int j=1;j<=2;j++)
            mid[i][j]=A.ans[i][j],A.ans[i][j]=0;
    for(re int i=1;i<=2;i++)
        for(re int j=1;j<=2;j++)
            for(re int p=1;p<=2;p++)
                A.ans[i][j]=(A.ans[i][j]+(mid[i][p]*B.ans[p][j])%mod)%mod;
    for(re int i=1;i<=2;i++)
        for(re int j=1;j<=2;j++)
            A.a[i][j]=A.ans[i][j];
}
int main()
{
    scanf("%s%s",N_+1,M_+1);
    a=read(),b=read(),c=read(),d=read();
    int lenn=strlen(N_+1);
    if(a==1) 
    {
        for(re int i=1;i<=lenn;i++) n=n*10%mod+N_[i]-48;
        n=(n-2+mod)%mod;
    }
    else 
    {
        for(re int i=1;i<=lenn;i++) n=n*10%(mod-1)+N_[i]-48;
        n=(n-2+mod-1)%(mod-1);
    }
    lenn=strlen(M_+1);
    if(c==1) 
    {
        for(re int i=1;i<=lenn;i++) m=m*10%mod+M_[i]-48;
        m=(m-2+mod)%mod;
    }
    else 
    {
        for(re int i=1;i<=lenn;i++) m=m*10%(mod-1)+M_[i]-48;
        m=(m-2+mod-1)%(mod-1);
    }
    R.a[1][1]=R.ans[1][1]=1;
    R.a[2][1]=R.ans[2][1]=b;
    R.ans[2][2]=R.a[2][2]=a;
    R.ans[1][2]=R.a[1][2]=0;
    K=R;
    while(m)
    {
        if(m&1) R.did_ans();
        m>>=1;
        R.did_a();
    }
    L.a[1][1]=L.ans[1][1]=1;
    L.a[2][1]=L.ans[2][1]=d;
    L.a[2][2]=L.ans[2][2]=c;
    L.ans[1][2]=L.a[1][2]=0;
    mul(L,R);
    while(n)
    {
        if(n&1) L.did_ans();
        n>>=1;
        L.did_a();
    }
    printf("%lld",((L.ans[2][1]+L.ans[2][2])%mod*R.ans[2][2]%mod+R.ans[2][1])%mod);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/asuldb/p/10207869.html