#loj3090 [BJOI2019] 勘破神机

简单线性代数练习题
首先翻开具体数学生成函数一章,可以发现$F(n),G(n)$满足以下递推式

$$F(n)=F(n-1)+F(n-2),F(0)=1,F(1)=1$$
$$G(n)=4G(n-2)-G(n-4),G(2)=3,G(0)=1$$

我们发现$G$只有偶数项有值(证明的话可以把网格黑白染色什么的来证明)

那么我们可以设$T(n)=G(2n)$

那么$T$服从以下递推式

$$T(n)=4T(n-1)-T(n-2),T(1)=1,T(0)=1$$

那么我们发现$F$和$G$是两个二阶常系数线性递推数列

根据特征方程和生成函数那一套理论,我们可以知道$F,T$都有这样的通项公式

$$C×A^n+B×D^n$$

考虑题目让我们算什么

$$\sum_{i=1}^{n}{F(i) \choose k}$$

众所周知,组合数是个$O(k)$阶的多项式,那么我们现在的问题就变成了快速求

$$\sum_{i=1}^{n}F(i)^k$$

那么稍微画一下式子就是

$$\sum_{i=1}^{n}(CA^i+DB^i)^k$$

暴力二项式定理展开

$$\sum_{i=1}^{n}\sum_{j=0}^{k}C^{j}A^{ij}D^{k-j}B^{i(k-j)}$$

稍微交换一下求和号

$$\sum_{j=0}^{k}C^{j}D^{k-j}\sum_{i=1}^{n}(A^{j}B^{k-j})^i$$

后面显然是个等比数列求和,可以$O(logn)$的算出

那么这题就做完了,复杂度$O(Tk^2logn)$

注意当$m=3$的时候,你需要将输入的区间缩一下,因为询问的是$G$但是我们的做法只能处理$T$

另外你可能需要将输入的区间加加减减,因为我们的公式一般认为$F(0)=1$

下面是简单的线性代数内容,是关于如何解出$A,B,C,D$的

对于F显然是fib数列的通项公式,你应该熟练的背过它

$$A=\frac{\sqrt{5}+1}{2}$$

$$B=\frac{\sqrt{5}-1}{2}$$

$$C=\frac{\sqrt{5}}{5}$$

$$D=-\frac{\sqrt{5}}{5}$$

对于$T$稍微解一下特征方程可以得到

$$A=2+\sqrt{3}$$

$$B=2-\sqrt{3}$$

那么我们令$T(1)=1,T(2)=1$,就可以通过待定系数法解出$C,D$来

可能涉及到$\sqrt{3}$的方程组会有点复杂,推荐的解法是

令$x=C+D,y=\sqrt{3}(C-D)$

然后你会发现方程组干净了很多,从而我们可以解出x和y来,借助x和y就可以还原C,D了

解得

$$C=\frac{3+\sqrt{3}}{6},D=\frac{3-sqrt{3}}{6}$$

最后一个问题是5和3在膜998244353剩余系下统统没有二次剩余

把每个数字表示成$r+v\sqrt{k}$的形式即可实现模意义复数,然后就能计算了

等比数列求和涉及除法,使用这种复数实现除法的时候用分母有理化推个式子就行了

// luogu-judger-enable-o2
#include<cstdio>
#include<algorithm>
using namespace std;const int N=1e5+10;
typedef unsigned long long ll;const ll mod=998244353;
inline ll po(ll a,ll p){ll r=1;for(;p;p>>=1,a=a*a%mod)if(p&1)r=r*a%mod;return r;}
# define md(x) (x=(x>=mod)?x-mod:x)
ll sqr;
struct cmp
{
    ll r;ll v;
    cmp(ll R=0,ll V=0){r=R;v=V;}
    friend cmp operator +(cmp a,cmp b)
    {
        cmp c=cmp(a.r+b.r,a.v+b.v);
        md(c.r);md(c.v);return c;
    }
    friend cmp operator -(cmp a,cmp b)
    {
        cmp c=cmp(a.r+mod-b.r,a.v+mod-b.v);
        md(c.r);md(c.v);return c;
    }
    friend cmp operator *(cmp a,cmp b)
    {
        return cmp((a.r*b.r+sqr*a.v%mod*b.v)%mod,(a.r*b.v+b.r*a.v)%mod);
    }
    friend cmp operator /(cmp a,cmp b)
    {
        cmp c=a*cmp(b.r,mod-b.v);
        ll iv=((b.r*b.r+(mod-sqr)*b.v%mod*b.v))%mod;
        iv=po(iv,mod-2);
        (c.r*=iv)%=mod;(c.v*=iv)%=mod;
        return c;
    }
}A,B,C,D;int T;int m;
// fib(n) =(A^n-B^n)*C
inline cmp cpo(cmp a,ll p)
{
    cmp r=cmp(1,0);
    for(;p;p>>=1,a=a*a)if(p&1)r=r*a;return r;
}
ll ifac[N];ll c[1010][1010];ll sr1[1010][1010];
inline void pre()
{   
    ifac[0]=1;ifac[1]=1;
    for(int i=2;i<N;i++)
        ifac[i]=(mod-mod/i)*ifac[mod%i]%mod;
    for(int i=1;i<N;i++)
        (ifac[i]*=ifac[i-1])%=mod;
    for(int i=0;i<1006;i++)
    {
        c[i][0]=c[i][i]=1;
        for(int j=1;j<i;j++)
            c[i][j]=(c[i-1][j-1]+c[i-1][j])%mod;
    }
    sr1[0][0]=1;
    for(int i=0;i<1006;i++)
        for(int j=1;j<=i;j++)
            sr1[i][j]=(sr1[i-1][j-1]+(mod-i+1)*sr1[i-1][j])%mod;    
    if(m==2)
    {   
        ll iv2=(mod+1)/2;sqr=5;
        A=cmp(iv2,iv2);B=cmp(iv2,mod-iv2);
        C=cmp(0,po(5,mod-2));
        D=cmp(C.r,mod-C.v);
    }else 
    {
        ll iv6=po(6,mod-2);sqr=3;
        A=cmp(2,1);B=cmp(2,mod-1);
        C=cmp(3*iv6%mod,(mod-1)*iv6%mod);
        D=cmp(C.r,mod-C.v); 
    } 
}
//mar <cmp> st;mar <cmp> trs;
inline cmp csum(cmp q,ll n)
{
    if(q.r==1&&q.v==0)
        return (cmp){(n+1)%mod,0};
    cmp res=cpo(q,n+1)-cmp(1,0);
    res=res/(q-cmp(1,0));
//  printf("q=%lld,%lld,res=%lld,%lld\n",q.r,q.v,res.r,res.v);
    return res;
}
inline ll cfibk(ll n,int k)
{
    cmp res=cmp(0,0);
    for(int j=0;j<=k;j++)
    {
        cmp q=cpo(A,k-j)*cpo(B,j);
        cmp xs=(cmp){c[k][j],0}*cpo(C,k-j)*cpo(D,j);
        res=res+xs*csum(q,n);
    }
    return res.r;
}
namespace solver2
{
    inline void solve()
    {
        ll l;ll r;int k;
        scanf("%lld%lld%d",&l,&r,&k);
        l++;r++;ll res=0;
        for(int i=1;i<=k;i++)
            (res+=(cfibk(r,i)+mod-cfibk(l-1,i))*sr1[k][i])%=mod;
        (res*=ifac[k])%=mod;
        (res*=po((r-l+1)%mod,mod-2))%=mod;
        printf("%lld\n",res);
    }
}
namespace solver3
{
    inline void solve()
    {
        ll l;ll r;int k;
        scanf("%lld%lld%d",&l,&r,&k);
        ll len=r-l+1;
        if(l&1)l++;l/=2;
        if(r&1)r--;r/=2;
        l++;r++;
        ll res=0;
        for(int i=1;i<=k;i++)
            (res+=(cfibk(r,i)+mod-cfibk(l-1,i))*sr1[k][i])%=mod;
        (res*=ifac[k])%=mod;
        (res*=po(len%mod,mod-2))%=mod;
        printf("%lld\n",res);
    }
}
int main()
{
    scanf("%d%d",&T,&m);
    pre();
    if(m==2)
        for(int z=1;z<=T;z++)
            solver2::solve();
    else 
        for(int z=1;z<=T;z++)
            solver3::solve();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/sweetphoenix/p/10786178.html
今日推荐