[洛谷P4720] [模板] 扩展卢卡斯

题目传送门

求组合数的时候,如果模数p是质数,可以用卢卡斯定理解决。

但是卢卡斯定理仅仅适用于p是质数的情况。

当p不是质数的时候,我们就需要用扩展卢卡斯求解。

实际上,扩展卢卡斯=快速幂+快速乘+exgcd求逆元+质因数分解+crt合并答案+求阶乘,跟卢卡斯定理没什么关系......

如果把模数p分解成p1^k1*p2^k2*...*px^kx的形式,那么我们可以求出c(n,m)分别模每个pi^ki的结果,再用中国剩余定理合并即可。

每个pi^ki一定是互质的,所以用朴素crt就行。

根据组合数的定义,c(n,m)=(n!) / (m!*(n-m)!) ,所以我们只要能想办法求出阶乘,就能再利用exgcd求出逆元,进而求出组合数。

接下来唯一的问题就是怎么快速求出 x! 取模 pi^ki 的结果。

考虑如下的经典样例(据说来自popoqqq):(19!)%(3^2)

19!=1*2*3*4*5*6*7*8*9*10*11*12*13*14*15*16*17*18*19

先把其中的3的倍数提出来,因为求组合数的时候分子分母能约掉。

19!=(1*2*4*5*7*8)*(10*11*13*14*16*17)*(19)*(3*6*9*12*15*18)=(1*2*4*5*7*8)*(1*2*4*5*7*8)*(3*3*3*3*3*3)*(1*2*3*4*5*6)=(1*2*4*5*7*8)^2*19*(3^6)*(1*2*3*4*5*6)。

后面的6!部分可以递归求解,递归终点为0!=1。

3^6最后计算组合数的时候再处理。

那几个(1*2*4*5*7*8)显然是循环的,循环节长度小于pi^ki,可以暴力计算。

显然一共有(x/(pi^ki))个循环节,套个快速幂即可。

剩下的部分,即19,长度等于x%(pi^ki),也小于pi^ki,也可以暴力计算。

至此我们求出了阶乘。

求组合数的时候,考虑pi的倍数的影响。

分子分母分别计数相加减。

最后用crt合并即可。

 1 #include<cstdio>
 2 typedef long long ll;
 3 
 4 ll n,m,p;
 5 
 6 ll ksm(ll b,ll tp,ll mod)
 7 {
 8     ll ret=1;
 9     while(tp)
10     {
11         if(tp&1)ret=ret*b%mod;
12         b=b*b%mod;
13         tp>>=1;
14     }
15     return ret;
16 }
17 
18 ll mul(ll a,ll b,ll mod)
19 {
20     ll ret=0;
21     while(b)
22     {
23         if(b&1)ret=(ret+a)%mod;
24         a=(a+a)%mod;
25         b>>=1;
26     }
27     return ret;
28 }
29 
30 ll exgcd(ll a,ll b,ll &x,ll &y)
31 {
32     if(!b)
33     {
34         x=1;y=0;
35         return a;
36     }
37     ll t=exgcd(b,a%b,y,x);
38     y-=a/b*x;
39 }
40 
41 ll inv(ll x,ll mod)
42 {
43     ll a,b;
44     exgcd(x,mod,a,b);
45     return (a%mod+mod)%mod;
46 }
47 
48 ll fac(ll x,ll pi,ll pk)
49 {
50     if(!x)return 1;
51     ll ans=1;
52     for(ll i=2;i<=pk;i++)
53         if(i%pi)ans=ans*i%pk;
54     ans=ksm(ans,x/pk,pk);
55     for(ll i=2;i<=x%pk;i++)
56         if(i%pi)ans=ans*i%pk;
57     return ans*fac(x/pi,pi,pk)%pk;
58 }
59 
60 ll c(ll cn,ll cm,ll pi,ll pk)
61 {
62     if(cm>cn)return 0;
63     ll up=fac(cn,pi,pk),d1=fac(cm,pi,pk),d2=fac(cn-cm,pi,pk);
64     ll cnt=0;
65     for(ll i=cn;i;i/=pi)cnt+=i/pi;
66     for(ll i=cm;i;i/=pi)cnt-=i/pi;
67     for(ll i=cn-cm;i;i/=pi)cnt-=i/pi;
68     return up*inv(d1,pk)%pk*inv(d2,pk)%pk*ksm(pi,cnt,pk)%pk;
69 }
70 
71 ll crt(ll a,ll pk)
72 {
73     return a*inv(p/pk,pk)%p*(p/pk)%p;
74 }
75 
76 int main()
77 {
78     scanf("%lld%lld%lld",&n,&m,&p);
79     ll tp=p,ans=0;
80     for(ll i=2;i*i<=p;i++)
81     {
82         if(tp%i)continue;
83         ll pk=1;
84         while(!(tp%i))tp/=i,pk*=i;
85         ans=(ans+crt(c(n,m,i,pk),pk))%p;
86     }
87     if(tp>1)ans=(ans+crt(c(n,m,tp,tp),tp))%p;
88     printf("%lld",(ans%p+p)%p);
89     return 0;
90 }

猜你喜欢

转载自www.cnblogs.com/eternhope/p/9898494.html