计算组合数的三种方式

摘要

本文主要介绍计算组合数的三种方式,这三种方式分别适用于不同的数据范围,其中涉及到了数论中的乘法逆元和费马小定理,卢卡斯定理。

递推

当数据范围小于等于 3000 3000 左右时,可直接根据组合数的一个性质: C ( n , m ) = C ( n 1 , m ) + C ( n 1 , m 1 ) C(n, m) = C(n-1,m) + C(n-1, m-1) 直接递归计算,时间复杂度为 O ( n 2 ) O(n^2) .

public static void get_C(int[][] arr){
    for(int i = 1; i <= MAXN; i++){
        for(int j = 1; j <= MAXN; j++){
            if(j == 1) arr[i][j] = i;
            else arr[i][j] = (arr[i-1][j] + arr[i-1][j-1])%1000000007;
        }
    }
}

预处理阶乘

当数据范围小于等于 1 0 5 10^5 左右时,我们要采取更快的方式计算组合数。

组合数公式: C ( n , m ) = C(n, m) = n ! m ! ( n m ) ! {n!}\over{m!(n-m)!}

如果我们可以预处理出所有阶乘,然后根据这个公式就可以在 O ( 1 ) O(1) 的时间复杂度内求出组合数,预处理所有阶乘的时间复杂度是 O ( n ) O(n) , 所以总的时间复杂度就是 O ( n ) O(n)

但是有一点需要注意的是:当结果很大需要对阶乘取模时,除法并不满足取模分配律。

也就是说: ( n ! m ! ( n m ) ! {n!}\over{m!(n-m)!} ) % P \%P \neq n ! % p ( m ! ( n m ) ! ) % p {n!\%p}\over{(m!(n-m)!)\%p}

那怎么办呢,我们既然必须要对其取模,还必须要用阶乘,不用除法怎么算呢?这要感谢伟大的前人发现了乘法逆元这个神奇的东西。可以将除法取模转化为乘法取模,乘法是满足取模分配律的。

那么什么是乘法逆元呢?

乘法逆元

对于一个整数a,如果 a b 1 ( m o d a*b≡1(mod p ) p) ,则在模 p p 的意义下:

a b b a a是b的逆元,b是a的逆元 。但此条件成立的前提是 a p g c d ( a , p ) = 1 a,p互质,即gcd(a,p)=1

这里暂且不说乘法逆元如何证明,只需要知道它怎么求,怎么用就可以了。

乘法逆元的计算方法:

乘法逆元有好几种求法,这里只简绍用费马小定理求逆元。

费马小定理:

如果 a a p p 互质,则有: a p 1 1 ( m o d a^{p-1}≡ 1 ( mod p ) p )

这个式子是不是和乘法逆元的式子特别相似:

a b 1 ( m o d a*b≡1(mod p ) p)
a p 1 1 ( m o d a^{p-1}≡ 1 ( mod p ) p )
对于两个式子,就差了一个 b b

我们可以将 a p 1 1 ( m o d a^{p-1}≡ 1 ( mod p ) p ) 转化为:
a a p 2 1 ( m o d a*a^{p-2}≡ 1 ( mod p ) p )

这样,我们就得到了 a a p p 的乘法逆元就是 a p 2 a^{p-2} 。我们可以用快速幂来求 a p 2 a^{p-2}


好了,知道了乘法逆元怎么用,怎么求,我们就可以来计算 n ! m ! ( n m ) ! {n!}\over{m!(n-m)!} % p \%p 了。

对于 n ! m ! ( n m ) ! {n!}\over{m!(n-m)!} % p \%p ,我们可以预处理出所有阶乘,和所有阶乘%p的逆元。

代码:

public static long qmi(long a, long b, long p){ // 快速幂求逆元
    long res = 1;
    while(b != 0){
        if((b&1) == 1){
            res = res * a % p; 
        }
        b >>= 1;
       a = a * a % p;
    }
    return res;
}
 public static void Init(){
     infact[0] = 1; // 存储逆元
     fact[0] = 1;
     for(int i = 1; i <= 100000; i++){
         fact[i] = fact[i-1] * i % mod; 
         infact[i] = infact[i-1] * qmi(i, mod - 2, mod) % mod;
    }
}

对于 infact[i] = infact[i-1] * qmi(i, mod - 2, mod) % mod;这行代码可能有些同学会有点迷,下面举个例子就能懂了。
例如:
( 3 4 ) p 2 % p = 3 p 2 % p 4 p 2 % p (3*4)^{p-2} \% p=3^{p-2}\%p*4^{p-2}\%p
所以,上述代码显然是成立的。

然后预处理完所有的阶乘和阶乘的逆元之后,就可以直接求组合数了。
设一共有n组数据,求每组数据的组合数。完整代码:

import java.io.*;
import java.util.*;

public class Main{
    static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    static BufferedWriter out = new BufferedWriter(new OutputStreamWriter(System.out));
    
    static final int N = 100005, mod = 1000000007;
    static int n;
    static long[] fact = new long[N];
    static long[] infact = new long[N];
    
    public static int Int(String s){
        return Integer.parseInt(s);
    } 
    
    public static long qmi(long a, long b, long p){
        long res = 1;
        while(b != 0){
            if((b&1) == 1){
                res = res * a % p; 
            }
            b >>= 1;
            a = a * a % p;
        }
        return res;
    }
    
    public static void Init(){
        infact[0] = 1;
        fact[0] = 1;
        for(int i = 1; i <= 100000; i++){
            fact[i] = fact[i-1] * i % mod; 
            infact[i] = infact[i-1] * qmi(i, mod - 2, mod) % mod;
        }
    }
    
    public static void main(String[] args)throws Exception{
        n = Int(in.readLine());
        Init();
        
        for(int i = 0; i < n; i++){
            String[] s = in.readLine().split(" ");
            int a = Int(s[0]);
            int b = Int(s[1]);
            out.write((((long)fact[a] * infact[b] % mod * infact[a-b]%mod)%mod) + "\n");
        }
        out.flush();
    }    
}

卢卡斯定理求组合数

如果数据范围非常大,在 1 0 18 10^{18} 内的话,用O(n)的方式也会超时了,我们需要另辟蹊径,伟大的前人又发现了一个神奇的东西:卢卡斯定理

卢卡斯定理: C ( n , m ) % p = C ( n / p , m / p ) C ( n % p , m % p ) % p C(n,m)\%p=C(n/p,m/p)*C(n\%p,m\%p)\%p

时间复杂度是: O ( l o g p ( n ) p ) , p O(log_p(n)*p),p必须是质数

所以当 n , m n,m 小于p时,我们就直接根据 n ! m ! ( n m ) ! {n!}\over{m!(n-m)!} % p \%p 并利用乘法逆元直接计算组合数的值,当 n , m n,m 大于p时我们就递归计算 C ( n / p , m / p ) C ( n % p , m % p ) % p C(n/p,m/p)*C(n\%p,m\%p)\%p 的值

代码:

import java.io.*;
import java.util.*;

public class Main{
    static BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
    static BufferedWriter out = new BufferedWriter(new OutputStreamWriter(System.out));
    
    static final int N = 100005, mod = 1000000007;
    static int n;
    static long[] fact = new long[N];
    static long[] infact = new long[N];
    
    public static long Int(String s){
        return Long.valueOf(s);
    } 
    
    public static long qmi(long a, long b, long p){
        long res = 1;
        while(b != 0){
            if((b&1) == 1){
                res = res * a % p; 
            }
            b >>= 1;
            a = a * a % p;
        }
        return res;
    }
    
    public static long C(long a, long b, long p){
        long res = 1;
        for(long i = 1, j = a; i <= b; i++, j--){
            res = res * j % p;
            res = res * qmi(i, p-2, p) % p;
        }
        return res;
    }
    
    public static long lucas(long a, long b, long p){
        if(a < p && b < p){
            return C(a, b, p);
        }
        else return C(a%p, b%p, p) * lucas(a/p, b/p, p) % p;
    }
    
    public static void main(String[] args)throws Exception{
        n = (int)Int(in.readLine());
        
        for(int i = 0; i < n; i++){
            String[] s = in.readLine().split(" ");
            long a = Int(s[0]);
            long b = Int(s[1]);
            long p = Int(s[2]);
            out.write(lucas(a, b, p) + "\n");
        }
        out.flush();
    }    
}
发布了77 篇原创文章 · 获赞 292 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/GD_ONE/article/details/104953289