BSGS及扩展BSGS算法及例题

\(BSGS(baby-step-giant-step)\)算法是用来解高次同余方程的最小非负整数解的算法,即形如这个的方程:

\(a^x\equiv b(mod\ p)\)

其中 \(p\)为质数(其实只要( \((a,p)=1\)即可)
首先考虑暴力怎么解:由费马小定理可知 \(a^{p-1}\equiv 1(mod\ p)\),也就是说如果在 \([0,p-1]\)内无解的话,方程就是无解的。所以我们从小到大枚举 \([0,p-1]\)中的每一个数,满足方程就结束。但是这里 \(p-1\)并不一定是最小正周期,这个可以由 的定义推出,有兴趣的同学可以去看一下。
但是这个方法在 \(p\)很大时就 \(GG\)啦,于是我们考虑优化一下:
\(x=im-j\),其中 \(m=\left \lceil \sqrt{p} \right \rceil\)。那么我们开始对方程进行变形:
\(a^{im}\equiv a^jb(mod\ p)\)

显然, \(j\)是余数,所以有 \(0\leqslant j \leqslant m-1\),也就是说右边的值最多只有 \(m\)个,而 \(i\)的最大值也只为 \(m\)。所以我们暴力枚举 \(j\),把 \(a^jb\)存到 \(map\)或哈希表里,再从小到大暴力枚举 \(i\),每次在哈希表里查一下有没有 \(a^jb\)使得方程成立就可以辣。
同时,有一个很妙的事情:我们枚举 \(j\)时是从小到大枚举的,每次插入到哈希表头的前面,所以表的越靠前的元素的 \(j\)值是越大的,而我们正好想要 \(j\)尽量大,一石二鸟,每次查到一个值就可以直接 \(return\)了。这样查询的近似复杂度就降到 \(O(1)\)了。
要注意先特判一下 \(x=0\)的情况。
板子代码:

namespace Ha { //哈希表
    int tot, h[MOD+5], ne[MOD+5];
    ll ha[MOD+5];
    void insert(ll x, ll num) { //插入操作
        ll t = num%MOD;
        p[++tot] = x, ha[tot] = num, ne[tot] = h[t], h[t] = tot;
    }
    ll query(ll tar) { //查询操作
        for(int i = h[tar%MOD]; i != -1; i = ne[i])
            if(ha[i] == tar) return p[i];
        return -1;
    }
}
using namespace Ha;
ll bsgs(ll a, ll b, ll p) {
    a %= p, b %= p;
    if(a == 0 && b != 0) return -1; //a%p==0时显然无解
    if(a == 0 && b == 0) return 1;
    if(b == 1) return 0;
    ll m = ceil(sqrt((double)p)), q = 1, x = 1;
    memset(h, -1, sizeof h); //记得清空
    for(ll j = 0; j < m; ++j) insert(j, q*b%p), q = q*a%p; //暴力枚举j并存入表中
    for(ll i = 1, j; i <= m; ++i) {
        x = x*q%p, j = query(x); //在表中找
        if(j != -1) return i*m-j; //找到解了,直接返回
    }
    return -1;
}

但是\(p\)\(a\)不互质了怎么办呢,这时候就要请上我们的扩展\(BSGS\)了,虽说是扩展,其实并不是难以理解。
首先扔一个定理:
若有\(p\equiv q(mod\ r)\),令\(d\)\(p,q,r\)的共同正因子,则有\(\frac{p}{d}\equiv \frac{q}{d}(mod\ \frac{r}{d})\)
证明如下:
易知\(p-q=kr\),两边同除\(d\),推出\(\frac{p}{d}-\frac{q}{d}=k\frac{r}{d}\),即\(\frac{p}{d}\equiv \frac{q}{d}(mod\ \frac{r}{d})\)
有了这个定理,我们就又双叒叕可以做这道题了:首先设\(t=1, cnt=0\),若\(a,c\)不互质,就令\(d\)为它们的的最大公约数,判断一下\(b\ mod\ d\)等不等于0,不等于零,就无解,否则令\(b\)等于\(\frac{b}{d}\)\(c\)等于\(\frac{c}{d}\),把\(\frac{a}{d}\)累乘到\(t\)上,让\(cnt\)自加。重复以上操作,直到\(a,c\)互质,然后我们会得到一个差不多长这样的方程:

\(ta^{x-cnt}\equiv b'(mod\ c')\),其中 \(a,c'\)互质

然后我们就将它转化为标准的 \(BSGS\)可解决的问题啦。
记得特判一下 \(x\leqslant cnt\)的情况( 为什么留给读者自己思考
粘一下代码:

ll bsgs(ll a, ll b, ll p) {
    a %= p, b %= p;
    if(a == 0 && b != 0) return -1;
        if(a == 0 && b == 0) return 1;
        if(b == 1) return 0;
    ll d = 1, t = 1, cnt = 0, m = ceil(sqrt(p)), q = 1;
    while((t = gcd(a, p)) != 1) {
        if(b%t) return -1;
        cnt++, b /= d, p /= d, t = t*(a/d)%p;
        if(b == d) return cnt; //特判
    }
    memset(h, -1, sizeof h);
    for(ll i = 0; i < m; ++i) insert(i, q*b%p), q = q*a%p;
    for(ll i = 1, j; i <= m+1; ++i) {
        t = t*q%MOD;
        j = query(t);
        if(j != -1) return i*m-j+cnt; //返回值要加上cnt
    }
    return -1;
}

放一道例题:洛谷/BZOJ。这应该算是板子了吧→_→
AC代码:

#include <bits/stdc++.h>

#define MOD 1000007
#define ll long long

int T, k;
ll p[MOD+5];

namespace Ha {
    int tot, h[MOD+5], ne[MOD+5];
    ll ha[MOD+5];
}

using namespace std;
using namespace Ha;

void insert(ll x, ll num) {
    ll t = num%MOD;
    p[++tot] = x, ha[tot] = num, ne[tot] = h[t], h[t] = tot;
}

ll query(ll tar) {
    for(int i = h[tar%MOD]; i != -1; i = ne[i])
        if(ha[i] == tar) return p[i];
    return -1;
}

ll bsgs(ll a, ll b, ll p) {
    a %= p, b %= p;
    if(a == 0) return -1;
    if(b == 1) return 0;
    ll m = ceil(sqrt((double)p)), q = 1, x = 1;
    memset(h, -1, sizeof h);
    for(ll i = 0; i < m; ++i) insert(i, q*b%p), q = q*a%p;
    for(ll i = 1, j; i <= m; ++i) {
        x = x*q%p, j = query(x);
        if(j != -1) return i*m-j;
    }
    return -1;
}

ll fpow(ll x, ll p, ll mod) {
    ll base = x%mod, ret = 1LL;
    while(p) {
        if(p&1) ret = ret*base%mod;
        base = base*base%mod;
        p >>= 1;
    }
    return ret;
}

ll inv(ll x, ll p) {
    return fpow(x, p-2, p);
}

int main() {
    cin >> T >> k;
    ll x, y, z, inv_y, ans;
    while(T--) {
        cin >> x >> y >> z;
        if(k == 1) cout << fpow(x, y, z) << endl;
        else if(k == 2) {
            if(x%z == 0) cout << "Orz, I cannot find x!" << endl;
            else cout << y*inv(x, z)%z << endl;
        }
        else {
            ans = bsgs(x, y, z);
            if(ans == -1) cout << "Orz, I cannot find x!" << endl;
            else cout << ans << endl;
        }
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/dummyummy/p/9770058.html