[bzoj4816][Sdoi2017]数字表格 (反演+逆元)

(真不想做莫比乌斯了)

首先根据题意写出式子

i=1~nj=1~m)f[gcd(i,j)]

很明显的f可以预处理出来,解决

根据套路分析,我们可以先枚举gcd(i,j)==d

(d=1~n)f[d]......后面该怎么写?

我们发现前面式子中i,j为连乘,而对于相同的gcd,就可以变成f[d]的几次幂!

则∏(d=1~n)f[d]Σ(i=1~n/d)Σ(j=1~m/d)[gcd(i,j)==1]

然后就可以开心的反演了

(d=1~n)f[d]Σ(i=1~n/d)Σ(j=1~m/d)[gcd(i,j)==1]

=∏(d=1~n)f[d]Σ(i=1~n/d)Σ(j=1~m/d)Σ(k|i&&k|j)μ(k)

扫描二维码关注公众号,回复: 6447002 查看本文章

(接下来,我们先枚举k)

=∏(d=1~n)f[d]Σ(k=1~n)μ[k](n/kd)(m/kd)

(先枚举kd=D)

=∏(D=1~n)(d|D)f[d]μ[D/d](n/D)(m/D)

=∏(D=1~n)(∏(d|D)f[d]μ[D/d])(n/D)(m/D)

至此反演结束

再来观察这个式子,我们发现∏(d|D)f[d]μ[D/d]是关于D的一个函数,我们可以把它的前缀积处理出来,复杂度O(n*log(n))

处理过程中,当μ[D/d]==-1时需要除法,所以需要求逆元,而对于1e9+7这个素数,f[i]对于1e9+7的逆元为pow(f[i],mod-2)

在求解时我们需要取一段的前缀积,所以还需要把前缀积的逆元处理出来,方法同上

逆元处理复杂度O(n*log(n))

在求解时结合数论分块和快速幂,复杂度O(T*sqrt(n)*log(n))

总复杂度O(n*log(n)+T*sqrt(n)*log(n))

这道题做的时候主要卡在把变成Σ并变成指数,在此做个标记

后续有优化可以把求逆元的复杂度干掉

具体参考这篇博客

 

AC代码

#include<cstdio>
#include<iostream>
#define ll long long
#define re register
const int mod=1e9+7;
using namespace std;
int p[500010],top;bool v[1000010];short mu[1000010];ll f[1000010],ni[1000010],tot[1000010];
inline ll pow(ll a,ll b){
    re ll ans=1;
    for(;b;b>>=1){
        if(b&1) (ans*=a)%=mod;
        (a*=a)%=mod;
    }
    return ans;
}
int main(){
    mu[1]=1;f[1]=1;ni[1]=1;tot[1]=1;
    for(int i=2;i<=1000000;i++)
      f[i]=(f[i-1]+f[i-2])%mod,ni[i]=pow(f[i],mod-2),tot[i]=1;
    for(int i=2;i<=1000000;i++){
        if(!v[i]){
            p[++top]=i;
            mu[i]=-1;
        }
        for(int j=1;j<=top&&p[j]*i<=1000000;j++){
            v[i*p[j]]=1;
            if(!(i%p[j])) break;
            mu[i*p[j]]=-mu[i];
        }
    }
    for(int i=1;i<=1000000;i++){
        for(re int j=1;j*i<=1000000;j++)
          if(mu[j]==-1) (tot[j*i]*=ni[i])%=mod;
          else if(mu[j]==1) (tot[j*i]*=f[i])%=mod;
    }
    tot[0]=1;
    for(int i=1;i<=1000000;i++) (tot[i]*=tot[i-1])%=mod;
    ni[0]=1;
    for(re int i=1;i<=1000000;i++)
      ni[i]=pow(tot[i],mod-2);
    re int t,n,m,x;
    re ll ans;
    scanf("%d",&t);
    while(t--){
        scanf("%d%d",&n,&m);ans=1;
        if(n>m) swap(n,m);
        for(int i=1;i<=n;i=x+1){
            x=min((n/(n/i)),(m/(m/i)));
            (ans*=pow(tot[x]*ni[i-1]%mod,1ll*(n/i)*(m/i)))%=mod;
        }
        printf("%lld\n",ans);
    }
    return 0;
}
View Code

猜你喜欢

转载自www.cnblogs.com/mikufun-hzoi-cpp/p/11013959.html