题目大意:给定正整数nn,求有多少个整数四元组
a,b,c,d∈[0,n−1]满足
ab=cd(modn)。由于
n非常大,将以质因数分解的形式给出
n=∏i=1mpici。
m≤5×105,p,c≤109。
题解:由中国剩余定理的结论知我们只需要求出
n=pc的答案然后乘起来即可。
考虑这个怎么做,不妨设
cnti=∑(a,b)[abmodn=i],那么答案就是
∑i=0n−1cnti2。显然
cnt0=n2−∑i=1n−1cnti,考虑某个
cnti(i>0)怎么算。
不妨设
i=pkq,满足
gcd(p,q)=1。显然
0≤k<c,q>0。
那么
ab=i,等价于
a=pk′a′,b=pk−k′b′,gcd(a′,p)=gcd(b′,p)=1,a′b′=q(modpc−k),并且每求出这样的一组
(a′,b′,k′),都会有
pk′×pk−k′=pk组
(a,b,k)与之对应,而
(a′,b′,k′)的组数显然就是
ϕ(pc−k),与
k′无关,因此对于每个
k′,答案就是
pkϕ(pc−k),因此
cnti=(k+1)pkϕ(pc−k)=(k+1)ϕ(pc)。
然后考虑对于每个
k有多少个
i,显然就是
ϕ(pc−k)
首先考虑
∑i=1n−1cnti2=∑k=0c−1(k+1)2ϕ2(pc)ϕ(pc−k)=ϕ2(pc)(p−1)pc∑k=1ck2(p1)k
后面那个怎么求:
S2=k=1∑nk2qkS2−n2qn=k=2∑n(k−1)2qk−1qS2−n2qn+1=k=2∑n(k2−2k+1)qkqS2−n2qn+1=k=2∑nk2qk−2k=2∑nkqk+k=2∑nqkqS2−n2qn+1=S2−q+2(S1−q)+S0−qS2=q−1n2qn+1−q+2(S1−q)+S0−q
其中
S1=∑k=1nkqk,S0=∑k=1nqk,求法同理。
剩余还有一个
cnt02,过程类似,略。
#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define Rep(i,v) rep(i,0,(int)v.size()-1)
#define lint long long
#define mod 1000000007
#define ull unsigned lint
#define db long double
#define pb push_back
#define mp make_pair
#define fir first
#define sec second
#define gc getchar()
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
typedef pair<int,int> pii;
typedef set<int>::iterator sit;
inline int inn()
{
int x,ch;while((ch=gc)<'0'||ch>'9');
x=ch^'0';while((ch=gc)>='0'&&ch<='9')
x=(x<<1)+(x<<3)+(ch^'0');return x;
}
inline int fast_pow(int x,int k,int ans=1) { for(;k;k>>=1,x=(lint)x*x%mod) (k&1)?ans=(lint)ans*x%mod:0;return ans; }
inline int inv(int x) { return fast_pow(x,mod-2); }
inline lint squ(int x) { return (lint)x*x; }
inline int solve(int p,int c)
{
int n=fast_pow(p,c),t=n-fast_pow(p,c-1);if(t<0) t+=mod;
int q=inv(p),v=fast_pow(q,c),z=inv(q-1),c0=q*(v-1ll)%mod*z%mod;
int c1=((lint)c*v%mod*q-c0)%mod*z%mod;if(c1<0) c1+=mod;
int c2=((lint)c*c%mod*v%mod*q%mod-q-2ll*(c1-q)%mod+(lint)q*(c0-v)%mod)*z%mod;if(c2<0) c2+=mod;
return ((lint)t*t%mod*(p-1)%mod*n%mod*c2+squ((lint)n*n%mod-t*(p-1ll)%mod*n%mod*c1%mod))%mod;
}
int main()
{
int ans=1;
for(int T=inn(),p,c;T;T--) p=inn(),c=inn(),ans=(lint)ans*solve(p,c)%mod;
return !printf("%d\n",ans);
}