洛谷P4238 【【模板】多项式求逆】

好题,证明的过程真的很妙。

正文部分:

题意:
有两个多项式\(A(x)\)\(B(x)\),满足\(A(x)*B(x)≡1(modx^n)\)给出\(A(x)\),求\(B(x)\)

\(A(x)*B'(x)≡1(modx^{ceil{n/2}})\)
观察得知,这两者可相减,于是得到这个式子:
\(B(x)-B'(x)≡0(modx^{ceiln/2})\)

将两边平方得到:
\(B(x)^2+B'(x)^2-2BB'(x)≡0(modx^{n})\)
不妨都乘\(A\),则:
\(B(x) +AB'(x)^2-2B'(x)≡0(modx^n)\)
移项得知:
\(B(x)=2B'(x)-AB'(x)^2\)
同时提\(B'\)

则:\(B(x)=B'(x)(2-AB'(x))\)
于是我们就找到了递归式,从上往下递归即可.

My Code

#include <bits/stdc++.h>
#define il inline
#define gc getchar
#define pc putchar
typedef int mi;
#define int long long
typedef long long LL;
const int MAXN = 1e6 + 10;
const LL p = 998244353;
using namespace std;
int n,m,i,j,k;
int R[MAXN],a[MAXN],b[MAXN],wn[MAXN];
namespace IO {
    il int read() {
        int res = 0;char c;bool sign = 0;
        for(c = gc();!isdigit(c);c = gc()) sign |= c == '-';
        for(;isdigit(c);c = gc()) res = (res << 1) + (res << 3) + (c ^ 48);
        return sign ? -res : res;
    }
}
using IO::read;
il int fpow(int a,int b,int mod = p) {
    int res = 1;
    while(b) {
        if(b & 1) res = (res * a) % mod;
        b >>= 1;a = (a * a) % mod;
    }
    return res;
}
il void ntt(LL a[],int len,int on) {
   for(int i = 0, j = 0; i < len; ++i) {
        if (i > j) swap(a[i], a[j]);
        for (int l = len >> 1; (j ^= l) < l; l >>= 1);
    }
    int id = 0;
    for(int i = 1;i < len;i <<= 1) {
        id++;
        for(int j = 0;j < len;j += i << 1) {
            LL w = 1;
            for(int k = 0;k < i;k++) {
                int x = a[j + k] % p;
                int y = w * a[j + k + i] % p;
                a[j + k] = (x + y) % p;
                a[j + k + i] = (x - y + p) % p;
                w = w * wn[id] % p;
            }
        }
    }
    if(!~on) {
        reverse(a + 1,a + len);
        int inv = fpow(len,p - 2,p);
        for(int i = 0;i < len;i++) a[i] = a[i] * inv % p;
    }
    return;
} 
void calc(int deg,LL *a,LL *b) {
    static LL tmp[MAXN];
    if(deg == 1) b[0] = fpow(a[0],p - 2);
    else {
        calc((deg + 1) >> 1,a,b);
        int _p = 1;
        while(_p < deg << 1) _p <<= 1;
        copy(a,a + deg,tmp);
        fill(tmp + deg,tmp + _p,0);
        ntt(tmp,_p,1);ntt(b,_p,1);
        for(int i = 0;i < _p;i++) {
            b[i] = (2 - b[i] * tmp[i] % p + p) % p * b[i] % p; 
        }
        ntt(b,_p,-1);fill(b + deg,b + _p,0);
    }
    return;
}
mi main() {
    for(int i = 0;i < 21;i++) wn[i] = fpow(3,(p - 1) / (1 << i),p);
    n = read();
    for(int i = 0;i < n;i++) a[i] = read();
    int li = 1,t = 0;while(li < n << 1) li <<= 1,t++;
    calc(n,a,b);
    for(int i = 0;i < n;i++) printf("%lld ",(b[i] + p) % p);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Sai0511/p/10360589.html
今日推荐