多项式各种操作

前些天学了一堆多项式的算法,今天总结一下。

乘法

朴素的NTT就不说了,主要说一下三模数NTT。
三模数NTT,顾名思义,就是选取三个适合做NTT的模数,然后把它们用CRT合并起来得到的答案再去对我们要求的模数取模。
为了方便,这三个模数分别是998244353, 1004535809和469762049。它们的原根都是3,且它们减去1的值都有超过20个2的因子。
如果你觉得能用int128的话就请忽略下面所有的表述
但是,在用CRT合并的时候,我们遇到一个麻烦:这三个模数的乘积太大了。我们令三个模数分别为p1, p2, p3,对三个模数取模得到的答案分别为a1, a2, a3,不做取模的原本的答案为Ans,则有:

A n s a 1   ( m o d   p 1 )

A n s a 2   ( m o d   p 2 )

A n s a 3   ( m o d   p 3 )

先用CRT合并前两个模数的答案,得到
A n s A   ( m o d   M )

A n s a 3   ( m o d   p 3 )


A n s = k M + A

那么
k M + A a 3   ( m o d   p 3 )


k ( a 3 A ) M   ( m o d   p 3 )

此时k, M, A我们都已经得到,就直接对我们要求的模数取模就好了,即
A n s   %   p = ( ( k   %   p ) ( M   %   p ) + A )   %   p

Code

inline void exntt(int *a, int *b)
{
    int len = 1, bit = 0;
    while (len < n + m) len <<= 1, bit++;
    for (int i = 0; i < len; i++) rev[i] = rev[i >> 1] >> 1 | (i & 1) << bit - 1;
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < len; j++) c[j] = a[j], d[j] = b[j];
        ntt(c, len, 1, p[i]);
        ntt(d, len, 1, p[i]);
        for (int j = 0; j < len; j++) ans[i][j] = 1ll * c[j] * d[j] % p[i];
        ntt(ans[i], len, -1, p[i]);
    }
    for (int i = 0; i < n + m - 1; i++) {
        ll A = (mul(1ll * ans[0][i] * p[1] % M, ksm(p[1] % p[0], p[0] - 2, p[0]), M) +
                mul(1ll * ans[1][i] * p[0] % M, ksm(p[0] % p[1], p[1] - 2, p[1]), M)) % M;
        ll k = ((ans[2][i] - A) % p[2] + p[2]) % p[2] * ksm(M % p[2], p[2] - 2, p[2]) % p[2];
        printf("%lld ", ((k % mod) * (M % mod) % mod + A % mod) % mod);
    }
}

逆元

多项式求逆是个好东西,基本上所有除了乘法外的东西都要用到求逆。
我们要求

A B 1   m o d   x n

假如我们已经求得
A B 1   m o d   x n 2

则有
B B   m o d   x n 2

B 移到左边然后平方得到
B 2 2 B B + B 2 0   m o d   x n

B 2 不好处理,所以同乘以一个A,即
B 2 B + A B 2 0   m o d   x n

我们就得到了
B 2 B + A B 2   m o d   x n

于是倍增求即可。

Code

inline void poly_inv(int *a, int *b, int n)
{
    if (n == 1) {
        b[0] = ksm(a[0], mod - 2);
        return;
    }
    poly_inv(a, b, n + 1 >> 1);
    int len = 1;
    while (len < n << 1) len <<= 1;
    for (int i = 0; i < n; i++) tmp[i] = a[i];
    for (int i = n; i < len; i++) tmp[i] = 0;
    ntt(tmp, len, 1);
    ntt(b, len, 1);
    for (int i = 0; i < len; i++)
        b[i] = 1ll * b[i] * (2 - 1ll * b[i] * tmp[i] % mod + mod) % mod;
    ntt(b, len, -1);
    for (int i = n; i < len; i++) b[i] = 0;
}

开方

多项式开方的思路和求逆一样,都是倍增的利用上一次的答案求下一次的。
简单推一下式子吧:

B 2 A   m o d   x n

B 2 A   m o d   x n 2

B 2 B 2 0   m o d   x n 2

这里出现了两组解,我们只保留
B B 0   m o d   x n 2

还是套路的平方
B 2 2 B B + B 2 0   m o d   x n

其实 B 2 就是 A
A 2 B B + B 2 0   m o d   x n

那么最后 B 就求出来了
B A + B 2 2 B   m o d   x n

Code

void poly_sqrt(int *a, int *b, int n)
{
    if (n == 1) {
        b[0] = 1; //一般情况下a[0] = 1
        return;
    }
    poly_sqrt(a, b, n + 1 >> 1);
    int len = 1;
    while (len < n << 1) len <<= 1;
    memset(c, 0, sizeof c);
    poly_inv(b, c, n);
    for (int i = 0; i < n; i++) tmp[i] = a[i];
    for (int i = n; i < len; i++) tmp[i] = 0;
    ntt(tmp, len, 1);
    ntt(b, len, 1);
    ntt(c, len, 1);
    for (int i = 0; i < len; i++) b[i] = 1ll * (tmp[i] + 1ll * b[i] * b[i] % mod) * c[i] % mod * inv2 % mod;
    ntt(b, len, -1);
    for (int i = n; i < len; i++) b[i] = 0;
}

除法和取模

已经有了求逆,那么多项式除法和取模又是什么?和逆元有什么区别吗?
举个例子:
由小学数学得到

157 12 = 13 , 157 1   m o d   12

然而, 12 13 并不等于 157 ,所以 157 乘以 12 的逆元也并不等于 13 ,这就是除法和求逆的区别。
那除法怎么求呢?
157 , 12 , 13 , 1 都看成多项式,我们发现是取模的结果 1 产生了上述影响。
如果我们把被除数多项式( 157 )、除数多项式( 12 )和商多项式( 13 )都翻转一下,就可以得到
21 31 = 651 , 751 651 = 100

我们发现,此时取模的结果 1 也被翻转成了100,而100的位数较高,不会产生导致除法和求逆结果不同的影响。
所以,我们只要用 751 去乘上 21 的逆元就可以得到 31 ,然后把 31 翻转回来就得到了 13
什么?你说取模?
你都求出了 13 ,你还不会用 157 12 13 来求 1 吗?

Code

inline void poly_div(int *a, int *b, int *d, int *r, int n, int m)
{
    for (int i = 0; i < n; i++) a1[i] = a[n - i - 1];
    for (int i = 0; i < m; i++) b1[i] = b[m - i - 1];
    int l = n - m + 1;
    poly_inv(b1, d, l);
    int len = 1;
    while (len < n + l) len <<= 1;
    ntt(a1, len, 1);
    ntt(d, len, 1);
    for (int i = 0; i < len; i++) d[i] = 1ll * a1[i] * d[i] % mod;
    ntt(d, len, -1);
    for (int i = l; i < len; i++) d[i] = 0;
    for (int i = 0; i < l >> 1; i++) swap(d[i], d[l - i - 1]);
    for (int i = 0; i < m; i++) r[i] = b[i];
    for (int i = 0; i < l; i++) b1[i] = d[i];
    for (int i = l; i < m; i++) b1[i] = 0;
    ntt(r, len, 1);
    ntt(b1, len, 1);
    for (int i = 0; i < len; i++) r[i] = 1ll * r[i] * b1[i] % mod;
    ntt(r, len, -1);
    for (int i = 0; i < n; i++) r[i] = (a[i] - r[i] + mod) % mod;
}

求ln

exp

这两个不想写了……留坑待填

猜你喜欢

转载自blog.csdn.net/star_city_7/article/details/81143378
今日推荐