[Luogu P4238 / P4239] 【模板】多项式求逆(加强版)

版权声明:欢迎转载蒟蒻博客,但请注明出处: https://blog.csdn.net/LPA20020220/article/details/85163925

洛谷传送门(标准版)

洛谷传送门(加强版)

题目描述

给定一个多项式 $F(x)$,请求出一个多项式 $G(x)$,满足 $F(x) G(x) \equiv 1 \pmod{x^n}$。系数对 $998244353$(标准版)或 $10^9+7$(加强版)取模。

输入输出格式

输入格式:

首先输入一个整数 $n$,表示输入多项式的项数(即多项式次数为 $n-1$)。
接着输入 $n$ 个整数,第 $i$ 个整数 $a_i$ 代表 $F(x)$ 中次数为 $i-1$ 的项的系数。

输出格式:

输出 $n$ 个整数,第 $i$ 个整数 $b_i$ 代表 $G(x)$ 中次数为 $i-1$ 的项的系数。

输入输出样例

输入样例#1:

5
1 6 3 4 9

输出样例#1:

1 1000000001 33 999999823 1020

说明

$1 \leq n \leq 10^5$,$0 \leq a_i \leq 10^9$

解题分析

假设我们已经求出了 $G'(x)$,使得 $F(x) G'(x) \equiv 1 \pmod{x^{\lceil n/2 \rceil}}$。显然也有 $F(x) G(x) \equiv 1 \pmod{x^{\lceil n/2 \rceil}}$。两式相减后,由于 $F(x)$ 的常数项可逆,可以约去 $F(x)$,所以有:
$$G(x) - G'(x) \equiv 0 \pmod{x^{\lceil n/2 \rceil}}$$
两边平方(因为 $2\lceil n/2 \rceil \geq n$,模数可以提升到 $x^n$)就有:
$$G^2(x) - 2G'(x)G(x) + G'^2(x) \equiv 0 \pmod{x^n}$$
两边乘上 $F(x)$,并利用 $F(x)G(x) \equiv 1 \pmod{x^n}$,就有:
$$G(x) - 2G'(x) + F(x)G'^2(x) \equiv 0 \pmod{x^n}$$
即 $G(x) \equiv 2G'(x) - F(x)G'^2(x) \pmod{x^n}$。因此可以递归地先算出 $G'(x)$,再做两次多项式乘法即可。总复杂度为 $T(n) = T(n/2) + O(n\log n) = O(n\log n)$。

边界条件:$n = 1$ 时,$G(x) = \mathrm{inv}(F[0])$,即 $F[0]$ 在模意义下的逆元(代码中用费马小定理 $F[0]^{p-2}$ 计算)。

代码如下:

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <algorithm>
// Contest shorthand macros.
#define R register  // storage-class hint; deprecated in C++11, removed in C++17
#define IN inline
#define W while
#define gc getchar()
#define ll long long
#define MOD 998244353   // NTT-friendly prime modulus (119 * 2^23 + 1)
#define G 3             // primitive root of MOD; forward-NTT twiddles are powers of G
#define Ginv 332748118  // G^{-1} mod MOD (3 * 332748118 == MOD + 1); inverse-NTT twiddles
#define MX 400500       // capacity of every global coefficient / transform array
// Reads the next unsigned decimal integer from stdin into x.
// Every non-digit character before the number (spaces, newlines, a '-' sign)
// is skipped, so a negative input is read as its absolute value.
// If end-of-file is reached before any digit, x is left as 0 instead of
// looping forever.
template <class T>
inline void in(T &x)
{
	x = 0;
	// Keep getchar()'s result as int: EOF must stay distinguishable from a
	// real byte, and isdigit() is only defined for unsigned-char values or EOF.
	int c = getchar();
	while (c != EOF && !isdigit(c)) c = getchar();
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
}
// n from the statement: number of coefficients read and printed.
int l;
// NTT scratch buffers used by getinv.
int a[MX], b[MX];
// res: the inverse being built; rev: bit-reversal permutation for the
// current transform length; ind: the input coefficients of F.
int res[MX], rev[MX], ind[MX];
IN int fpow(R int base, R int tim)
{
	int ret = 1;
	W (tim)
	{
		if (tim & 1) ret = 1ll * ret * base % MOD;
		base = 1ll * base * base % MOD, tim >>= 1;
	}
	return ret;
}
IN void NTT(int *dat, R int typ, R int len)
{
	for (R int i = 0; i < len; ++i) if (rev[i] > i) std::swap(dat[i], dat[rev[i]]);
	R int seg, now, cur, step, bd, buf1, buf2, deal, base;
	for (seg = 1; seg < len; seg <<= 1)
	{
		base = fpow(typ ? G : Ginv, (MOD - 1) / (seg << 1)); step = seg << 1;
		for (now = 0; now < len; now += step)
		{
			deal = 1, bd = now + seg;
			for (cur = now; cur < bd; ++cur, deal = 1ll * deal * base % MOD)
			{
				buf1 = dat[cur], buf2 = 1ll * dat[cur + seg] * deal % MOD;
				dat[cur] = (buf1 + buf2) % MOD, dat[cur + seg] = (buf1 - buf2 + MOD) % MOD;
			}
		}
	}
	if (typ) return; int inv = fpow(len, MOD - 2);
	for (R int i = 0; i < len; ++i) dat[i] = 1ll * dat[i] * inv % MOD;
}
void getinv(R int up)
{
	if (up == 1) return res[0] = fpow(ind[0], MOD - 2), void();
	int half = up + 1 >> 1, len, lg; getinv(half);
	for (len = 1, lg = 0; len <= (up << 1); ++lg, len <<= 1);
	for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
	for (R int i = 0; i < len; ++i) b[i] = a[i] = 0;
	for (R int i = 0; i < up; ++i) a[i] = ind[i];
	for (R int i = 0; i < half; ++i) b[i] = res[i];
	NTT(a, 1, len), NTT(b, 1, len);
	for (R int i = 0; i < len; ++i) a[i] = (2 * b[i] % MOD - 1ll * a[i] * b[i] % MOD * b[i] % MOD + MOD) % MOD;
	NTT(a, 0, len);
	for (R int i = 0; i < up; ++i) res[i] = a[i];
}
// Entry point: read n and the n coefficients of F, compute F^{-1} mod x^n
// with getinv, and print the coefficients separated by spaces.
int main()
{
	in(l);
	for (int *p = ind; p != ind + l; ++p) in(*p);
	getinv(l);
	for (const int *p = res; p != res + l; ++p) printf("%d ", *p);
	return 0;
}

那么问题来了:如果模数不是适用于 NTT 的质数(例如 $10^9+7$),怎么办?

这里直接用 FFT 算完再取模的主要问题在于:卷积结果的系数可能达到 $10^9 \times 10^9 \times 10^5 = 10^{23}$,远远超出了 double 能精确表示的范围(约 $2^{53} \approx 9 \times 10^{15}$)。

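有个很直接的想法:取 $B = \lfloor\sqrt{P}\rfloor$($P$ 为模数),把每个系数 $x$ 拆成 $x = k \times B + p$($0 \leq p < B$)。然后对 $k$ 和 $p$ 分别做 FFT,算出四个部分积 $k_1k_2$、$k_1p_2$、$p_1k_2$、$p_1p_2$,再按权重 $B^2$、$B$、$B$、$1$ 合起来即可。这样每个卷积系数至多约 $10^5 \times (3.2 \times 10^4)^2 \approx 10^{14}$,用 long double 做 FFT 一般可以接受。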

时间复杂度仍是 $O(N\log N)$(每层规模减半,$T(n) = T(n/2) + O(n\log n)$),但常数巨大,实际表现接近 $O(N\log^2 N)$。

代码如下:

#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cctype>
#include <algorithm>
#include <cstring>
// Contest shorthand macros.
#define R register      // storage-class hint; deprecated in C++11, removed in C++17
#define IN inline
#define W while
#define gc getchar()
#define MX 400500       // capacity of every global coefficient / FFT array
#define db long double  // floating type used by the FFT (wider mantissa than double on x87)
#define ll long long
#define MOD 1000000007ll  // 10^9 + 7: not NTT-friendly, hence the split-coefficient FFT
// Split radix for coefficients: each value v < MOD is handled as the pair
// (v / BASE, v % BASE), keeping FFT inputs near sqrt(MOD).
// floor(sqrt(10^9 + 7)) == 31622.
const int BASE = static_cast<int>(std::sqrt(MOD));
// Reads the next unsigned decimal integer from stdin into x.
// Every non-digit character before the number (spaces, newlines, a '-' sign)
// is skipped, so a negative input is read as its absolute value.
// If end-of-file is reached before any digit, x is left as 0 instead of
// looping forever.
template <class T>
inline void in(T &x)
{
	x = 0;
	// Keep getchar()'s result as int: EOF must stay distinguishable from a
	// real byte, and isdigit() is only defined for unsigned-char values or EOF.
	int c = getchar();
	while (c != EOF && !isdigit(c)) c = getchar();
	while (c != EOF && isdigit(c))
	{
		x = x * 10 + (c - '0');
		c = getchar();
	}
}
// Number of coefficients of F (read from input).
int n;
// ans: the inverse being built; ind: the input coefficients of F.
int ans[MX], ind[MX];
// c1, c2: operands handed to Poly::Mul; res: its product.
int c1[MX], c2[MX], res[MX];
// Bit-reversal permutation for the current FFT length.
int rev[MX];
namespace Poly
{
	const db PI = std::acos(-1.0);
	struct Complex {db re, im;} a[MX], b[MX], c[MX], d[MX], w[MX];
	IN Complex operator * (const Complex &x, const Complex &y)
	{return {x.re * y.re - x.im * y.im, x.re * y.im + x.im * y.re};}
	IN Complex operator + (const Complex &x, const Complex &y)
	{return {x.re + y.re, x.im + y.im};}
	IN Complex operator - (const Complex &x, const Complex &y)
	{return {x.re - y.re, x.im - y.im};}
	IN int fpow(R int base, R int tim)
	{
		int ret = 1;
		W (tim)
		{
			if (tim & 1) ret = 1ll * ret * base % MOD;
			base = 1ll * base * base % MOD, tim >>= 1;
		}
		return ret;
	}
	IN void FFT(Complex *dat, R int len, R int typ)
	{
		for (R int i = 1; i < len; ++i) if (rev[i] > i) std::swap(dat[rev[i]], dat[i]);
		R int cur, now, seg, bd, step, tag, id; Complex buf1, buf2;
		for (seg = 1; seg < len; seg <<= 1)
		{
			step = seg << 1, tag = len / seg;
			for (now = 0; now < len; now += step)
			{
				bd = now + seg, id = 0;
				for (cur = now; cur < bd; ++cur, id += tag)
				{
					buf1 = dat[cur], buf2 = dat[cur + seg] * (Complex){w[id].re, typ * w[id].im};
					dat[cur] = buf1 + buf2, dat[cur + seg] = buf1 - buf2;
				}
			}
		}
		if (typ > 0) return;
		for (R int i = 0; i < len; ++i) dat[i].re /= len;
	}
	IN void Mul(R int up, int *m1, int *m2, int *res)
	{
		R int len = 1, lg = 0;
		Complex e, f, g, h;
		for (; len <= up; len <<= 1, lg++);
		for (R int i = 0; i < (len << 1); ++i) a[i] = b[i] = c[i] = d[i] = {0, 0};
		for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
		for (R int i = 1; i < len; i <<= 1) for (R int j = 0; j < i; ++j)
		w[len / i * j] = {std::cos(j * PI / i), std::sin(j * PI / i)};
		for (R int i = 0; i < up; ++i)
		{
			m1[i] %= MOD, m2[i] %= MOD;
			a[i] = {m1[i] / BASE, 0};
			b[i] = {m1[i] % BASE, 0};
			c[i] = {m2[i] / BASE, 0};
			d[i] = {m2[i] % BASE, 0};
		}
		FFT(a, len, 1), FFT(b, len, 1), FFT(c, len, 1), FFT(d, len, 1);
		for (R int i = 0; i < len; ++i)
		{
			e = a[i], f = b[i], g = c[i], h = d[i];
			a[i] = e * g, b[i] = e * h, c[i] = f * g, d[i] = f * h;
		}
		FFT(a, len, -1), FFT(b, len, -1), FFT(c, len, -1), FFT(d, len, -1);
		for (R int i = 0; i < up; ++i)
		{
			res[i] = (ll)(a[i].re + 0.5) % MOD * BASE % MOD * BASE % MOD;
			(res[i] += (ll)(b[i].re + 0.5) % MOD * BASE % MOD) %= MOD;
			(res[i] += (ll)(c[i].re + 0.5) % MOD * BASE % MOD) %= MOD;
			(res[i] += (ll)(d[i].re + 0.5) % MOD) %= MOD;
		}
	}
	IN void Getinv(R int up)
	{
		if (up == 1) return ans[0] = fpow(ind[0], MOD - 2), void();
		int half = up + 1 >> 1; Getinv(half);
		for (R int i = 0; i < up; ++i) c1[i] = ind[i], c2[i] = ans[i];
		Mul(up, c1, c2, res);
		for (R int i = 0; i < up; ++i) c1[i] = res[i];
		Mul(up, c1, c2, res);
		for (R int i = 0; i < up; ++i) ans[i] = (2 * ans[i] % MOD - res[i] + MOD) % MOD;
	}
}
// Entry point: read n and the n coefficients of F, compute F^{-1} mod x^n
// (coefficients modulo 10^9 + 7) with Poly::Getinv, and print them separated
// by spaces.
int main()
{
	in(n);
	for (int *p = ind; p != ind + n; ++p) in(*p);
	Poly::Getinv(n);
	for (int k = 0; k < n; ++k) printf("%d ", ans[k]);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/LPA20020220/article/details/85163925
今日推荐