洛谷P4721 分治FFT(多项式求逆模板+生成函数)

https://www.luogu.org/problem/P4721

这道题不算是一道裸的多项式求逆模板。

首先这道题题目说的是分治FFT,我反正弄了几天,开始学FFT就看到这道题了,没有弄出来,到处瞎想,太弱了。

随着知识的深入后来发现这是一道模板题。

其实这道题和生成函数应该不是很大,即使你不知道,我想可能乱搞也会出来。。。。。

我们来开始推导一下。

首先有这个式子:

f[i]=\sum_{j=1}^{i}f[i-j]g[j]

然后我们知道一个序列可以表示成多项式。我们设生成函数F(x)和G(x):

F(x)=\sum_{i=0}^{\infty }f[i]x^{i}

G(x)=\sum_{i=0}^{\infty }g[i]x^{i}

我们对于没有的地方全部规定为零。(和信号与系统的离散信号那套差不多)

这里g[0]没有说明,所以为0,其他没有说明的也为零。

我们开始做卷积:

F(x)*G(x)=\sum_{i=0}^{\infty }\sum_{j=0}^{\infty}(f[i]\cdot g[j])x^{i+j}

我们在变一下:

F(x)*G(x)=\sum_{k=0}^{\infty}x^{k}\sum_{i+j=k}f[i]g[j]=\sum_{k=0}^{\infty}x^{k}\sum_{i=0}^{k}f[k-i]g[i]

我们发现后面的一项很熟悉,因为g[0]等于0,当k=0时后面的一部分等于0,当k>0时,就等于f[k].

卷积就直接等于:

F(x)*G(x)=\sum_{k=1}^{\infty}f[k]x^{k}

和F(x)就差了一项f(0)x^0=f(0),又因为f(0)等于1所以在变一下:

F(x)*G(x)+1=F(x)

F(x)=\frac{1}{1-G(x)}

就是多项式求逆了。

下面稍微联想一下信号与系统上关于离散卷积的运算:

根据离散卷积的公式:y[n]=(x*h)[n]=\sum_{i=-\infty}^{\infty}x[i]h[n-i]=\sum_{i=-\infty}^{\infty}x[n-i]h[i]

吧g和f当成由于这都是正时间轴上的序列,其他时域不存在信号因此其他算进去了也无意义,因此带入上面的公式新的序列

y[n]=\sum_{i=0}^{n}f[n-i]g[i]因此这样就轻松得到了上面的卷积答案,并且还知道了,新序列y的开始地方就是g和h开始的坐标相加。

其实我就是这样推的。。。。。根本没有用生成函数(电子专业)。

剩下的就是多项式求逆,就是一个模板,也帖一个模板方便自己。

#include "bits/stdc++.h"

using namespace std;
const double eps = 1e-6;
#define reg register
#define lowbit(x) x&-x
#define pll pair<ll,ll>
#define pii pair<int,int>
#define fi first
#define se second
#define makp make_pair
#define cp complex<double>

int dcmp(double x) {
    if (fabs(x) < eps) return 0;
    return (x > 0) ? 1 : -1;
}

typedef long long ll;
typedef unsigned long long ull;
const ull hash1 = 201326611;
const ull hash2 = 50331653;
const ll N = 280000 + 10;
const int M = 1000000;
const int inf = 0x3f3f3f3f;
const ll mod = 998244353;
const double PI = acos(-1.0);
ll Mod(ll x) {
    if (x >= mod) x -= mod;
    return x;
}

ll quick(ll a, ll n) {
    ll ans = 1;
    while (n) {
        if (n & 1) ans = ans * a % mod;
        a = a * a % mod;
        n >>= 1;
    }
    return ans;
}

ll r[N], g, tot, lim;
void init(int len) {
    tot = 1, lim = 0;
    while (tot < 2 * len) tot <<= 1, lim++;
    for (int i = 0; i < tot; i++) {
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (lim - 1));
    }
}

void ntt(ll *a, int tot, int inv) {
    for (int i = 0; i < tot; i++) {
        if (i < r[i]) swap(a[i], a[r[i]]);
    }
    for (int l = 2; l <= tot; l <<= 1) {
        ll tmp = quick(g, (mod - 1) / l);
        if (inv) tmp = quick(tmp, mod - 2);
        int m = l / 2;
        for (int j = 0; j < tot; j += l) {
            ll w = 1;
            for (int i = 0; i < m; i++) {
                ll t = 1LL * a[j + i + m] * w % mod;
                a[j + i + m] = Mod(a[j + i] - t + mod);
                a[j + i] = Mod(a[j + i] + t);
                w = 1LL * w * tmp % mod;
            }
        }
    }
    if (inv) {
        ll t = quick(tot, mod - 2);
        for (int i = 0; i < tot; i++) {
            a[i] = 1LL * a[i] * t % mod;
        }
    }
}
int n;
ll a[N], b[N], c[N];
void solve(int len, ll *a, ll *b) {
    if (len == 1) {
        b[0] = quick(a[0], mod - 2);
        return;
    }
    solve((len + 1) >> 1, a, b);
    init(len);
    for (int i = 0; i < len; i++) c[i] = a[i];
    for (int i = len; i < tot; i++) c[i] = 0;
    ntt(c, tot, 0);
    ntt(b, tot, 0);
    for (int i = 0; i < tot; i++) {
        b[i] = 1LL * (2 - 1LL * c[i] * b[i] % mod + mod) % mod * b[i] % mod;
    }
    ntt(b, tot, 1);
    for (int i = len; i < tot; i++) b[i] = 0;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i < n; i++) {
        scanf("%lld", &a[i]);
        a[i] = -a[i];
    }
    a[0] = 1;
    g = 3;
    solve(n, a, b);
    for (int i = 0; i < n; i++)
        printf("%lld ", b[i]);
    return 0;
}
发布了130 篇原创文章 · 获赞 80 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/KXL5180/article/details/98656697