uva11916 bsgs算法逆元模板,求逆元,组合计数

其实思维难度不是很大,但是各种处理很麻烦,公式推导到最后就是一个bsgs算法解方程

/*
要给M行N列的网格染色,其中有B个不用染色,其他每个格子涂一种颜色,同一列上下两个格子不能染相同的颜色
涂色方案%100000007的结果是R,现在给出R,N,K,请求出最小的M 
对于第一行来说,每个位置有k种选择,那么填色方案数是k^n
对于第二行来说,每个位置有k-1中选择,那么填色方案数时(k-1)^n种
依次类推,如果i+1行的某个格子上面是白格,那么这个格子有k种填色方案

将M行分为两部分,第一部分是固定的,即行数最大的B向下一行,注意特判情况 
第二部分是不固定的,即不停增加行数M,直到求出结果=R 

另P=(K-1)^N,所以方案总数是cnt*P^M=R (mod 100000007) 
P^M = cnt^-1 * R(mod 100000007)
逆元算一下即可 
用bsgs算法 解出这个关于M的方程即可 
*/
#include<bits/stdc++.h>
using namespace std;
#define ll long long 
#define maxn 510
#define mod 100000007 

int n,m,k,b,r,x[maxn],y[maxn];
set<pair<int,int> >best;

ll pow_mod(ll a,ll p){//快速幂 
    ll res=1;
    while(p){
        if(p%2)
            res=res*a%mod;
        p>>=1;
        a=a*a%mod;
    }
    return res;
}
ll exgcd(ll a,ll b,ll &x,ll &y){
    if(b==0){x=1;y=0;return a;}
    ll d=exgcd(b,a%b,y,x);
    y-=a/b*x;
    return d;
}
ll inv(ll a){//ax+y*mod=1 ==> ax=1(mod mod),所以x就是a关于mod的逆元 
    ll d,x,y;
    d=exgcd(a,mod,x,y);
    return d==1?(x+mod)%mod:-1;
}
int log_mod(int a,int b){//bsgs算法,求解a^x=b(mod m)方程 
    int m, v, e = 1, i;
    m = (int)sqrt(mod+0.5);
    v = inv(pow_mod(a, m));
    map<int, int> x;
    x[1] = 0;
    
    for(int j=1;j<m;j++){//建立hash表,x=i*m+j 
        e=(ll)e*a%mod;
        if(!x.count(e))
            x[e]=j;
    } 
    for(int i=0;i<m;i++){
        if(x.count(b))
            return i*m+x[b];
        b=(ll)b*v%mod;//这里实际上是用逆元处理了,即将a^(i*m+j)=b (mod m)转化为a^j=b^(i*m)^(-1) (mod m) 
    } 
    return -1;
} 
int count(){//计算固定部分的方案数 
    int c=0;//统计b块下面的的方格
    for(int i=0;i<b;i++)
        if(x[i]!=m && !best.count(make_pair(x[i]+1,y[i])))
            c++; 
    c+=n; 
    for(int i=0;i<b;i++)
        if(x[i]==1)
            c--;
    
    return pow_mod(k-1,(ll)m*n-b-c)*pow_mod(k,c)%mod;
}
int doit(){
    int cnt=count();//先求出第一部分的cnt
    if(cnt==r)
        return m; 
    
    int c=0;//要把第m+1行单独拿出来考虑 
    for(int i=0;i<b;i++)
        if(x[i]==m)
            c++;
    m++;
    cnt=cnt*pow_mod(k,c)%mod;
    cnt=cnt*pow_mod(k-1,n-c)%mod;
    if(cnt==r)
        return m;
    
    //接下去就只要求对数方程即可
    int P=pow_mod(k-1,n); 
    return log_mod(P,r*inv(cnt)%mod)+m; 
}
int main(){
    int t,cas=1;
    scanf("%d",&t);
    while(t--){
        scanf("%d%d%d%d",&n,&k,&b,&r);
        best.clear();
        m=1;
        for(int i=0;i<b;i++){
            scanf("%d%d",&x[i],&y[i]);
            m=max(x[i],m);
            best.insert(make_pair(x[i],y[i]));
        }
        printf("Case %d: %d\n",cas++,doit());
    }
} 

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/10442487.html