【洛谷4721】【模板】分治FFT(CDQ分治_NTT)

题目:

洛谷 4721

分析:

我觉得这个 “分治 FFT ” 不能算一种特殊的 FFT ,只是 CDQ 分治里套了个用 FFT (或 NTT)计算的过程,二者是并列关系而不是偏正关系,跟 CDQ 分治套树状数组之类性质差不多吧(所以我也不知道为什么洛谷要把这个作为一个模板)。

言归正传,先看一眼原来的式子:

\[f[i]=\begin{cases}1\ (i=0)\\\sum_{j=1}^{i}f[i-j]g[j]\ \mathrm{otherwise}\end{cases}\]

\(f[i]=\sum f[i-j]g[j]\) 很像一个多项式卷积,只是后面的值要用到前面的值,不能直接卷积。考虑 CDQ 分治计算区间 \([l,r]\) 的一般过程:先递归左区间 \([l,mid]\) ,再计算左区间的值对右区间的值的贡献,最后递归右区间 \((mid,r]\)

如何计算 “左区间的值对右区间的值的贡献” 呢?考虑 \(f[i](l\leq i\leq mid)\) 这一项对 \((mid,r]\) 的贡献:

\[f[i+j]=\sum f[i]g[j] (i\in [l, mid], j\in [0, r-i])\]

(注意 \(i+j\leq mid\) 的情况已经在递归左区间时计算过,直接忽略掉即可)

设多项式 \(A[i-l]=\begin{cases}f[i](l\leq i \leq mid)\\0(mid<i\leq r)\end{cases}\) (后半部分置 \(0\) 凑够长度),则有:

\[f[l+i+j]=\sum A[i]g[j] (i\in [0, r-l], j\in [0, r-l-i])\]

写得更清晰一点:

\[f[l+i]=\sum A[j]g[i-j] (i\in [0, r-l], j\in [0, i])\]

这个直接拿卷积算就好了。

代码:

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cctype>
using namespace std;

namespace zyt
{
    template<typename T>
    inline bool read(T &x)
    {
        char c;
        bool f = false;
        x = 0;
        do
            c = getchar();
        while (c != EOF && c != '-' && !isdigit(c));
        if (c == EOF)
            return false;
        if (c == '-')
            f = true, c = getchar();
        do
            x = x * 10 + c - '0', c = getchar();
        while (isdigit(c));
        if (f)
            x = -x;
        return true;
    }
    template<typename T>
    inline void write(T x)
    {
        static char buf[20];
        char *pos = buf;
        if (x < 0)
            putchar('-'), x = -x;
        do
            *pos++ = x % 10 + '0';
        while (x /= 10);
        while (pos > buf)
            putchar(*--pos);
    }
    typedef long long ll;
    const int N = 1e5 + 10, LEN = N << 2, p = 998244353, g = 3;
    inline int power(int a, int b)
    {
        int ans = 1;
        while (b)
        {
            if(b & 1)
                ans = (ll)ans * a % p;
            a = (ll)a * a % p;
            b >>= 1;
        }
        return ans;
    }
    inline int inv(const int a)
    {
        return power(a, p - 2);
    }
    namespace Polynomial
    {
        int rev[LEN], omega[LEN], winv[LEN];
        void init(const int n, const int lg2)
        {
            int w = power(g, (p - 1) / n), wi = inv(w);
            omega[0] = winv[0] = 1;
            for (int i = 1; i < n; i++)
            {
                omega[i] = (ll)omega[i - 1] * w % p;
                winv[i] = (ll)winv[i - 1] * wi % p;
            }
            for (int i = 0; i < n; i++)
                rev[i] = ((rev[i >> 1] >> 1) | ((i & 1) << (lg2 - 1)));
        }
        void ntt(int *a, const int *w, const int n)
        {
            for (int i = 0; i < n; i++)
                if (i < rev[i])
                    swap(a[i], a[rev[i]]);
            for (int l = 1; l < n; l <<= 1)
                for (int i = 0; i < n; i += (l << 1))
                    for (int k = 0; k < l; k++)
                    {
                        int x = a[i + k], y = (ll)w[n / (l << 1) * k] * a[i + l + k] % p;
                        a[i + k] = (x + y) % p;
                        a[i + l + k] = (x - y + p) % p;
                    }
        }
        void mul(const int *a, const int *b, int *c, const int n)
        {
            static int x[LEN], y[LEN];
            int m = 1, lg2 = 0;
            while (m < n + n - 1)
                m <<= 1, ++lg2;
            memcpy(x, a, sizeof(int[n]));
            memset(x + n, 0, sizeof(int[m - n]));
            memcpy(y, b, sizeof(int[n]));
            memset(y + n, 0, sizeof(int[m - n]));
            init(m, lg2);
            ntt(x, omega, m), ntt(y, omega, m);
            for (int i = 0; i < m; i++)
                x[i] = (ll)x[i] * y[i] % p;
            ntt(x, winv, m);
            int invm = inv(m);
            for (int i = 0; i < n; i++)
                c[i] = (ll)x[i] * invm % p;
        }
    }
    int arr[N], ans[N], n;
    void solve(const int l, const int r)
    {
        static int tmp1[N], tmp2[N];
        if (l == r)
            return;
        int mid = (l + r) >> 1;
        solve(l, mid);
        for (int i = l; i <= mid; i++)
            tmp1[i - l] = ans[i];
        for (int i = mid + 1; i <= r; i++)
            tmp1[i - l] = 0;
        Polynomial::mul(arr, tmp1, tmp2, r - l + 1);
        for (int i = mid + 1; i <= r; i++)
            ans[i] = (ans[i] + tmp2[i - l]) % p;
        solve(mid + 1, r);
    }
    int work()
    {
        read(n);
        for (int i = 1; i < n; i++)
            read(arr[i]);
        ans[0] = 1;
        solve(0, n - 1);
        for (int i = 0; i < n; i++)
            write(ans[i]), putchar(' ');
        return 0;
    }
}
int main()
{
    freopen("4721.in", "r", stdin);
    return zyt::work();
}

猜你喜欢

转载自www.cnblogs.com/zyt1253679098/p/10534403.html