[Luogu P5050] 【模板】多项式多点求值

版权声明:欢迎转载蒟蒻博客,但请注明出处: https://blog.csdn.net/LPA20020220/article/details/85240439

洛谷传送门

题目描述

给定一个 $n$ 次多项式 $f(x)$,现在请你对于 $i \in [1,m]$,求出 $f(a_i)$。

输入输出格式

输入格式:

第一行两个正整数 $n,m$ 表示多项式的次数及你要求的点值的数量。

第二行 $n+1$ 个非负整数,由低到高地给出多项式的系数。

第三行 $m$ 个非负整数,表示 $a_i$。

输出格式:

一共 $m$ 行,每行 $1$ 个非负整数。

第 $i$ 行的数字表示 $f(a_i)$。

答案对 $998244353$ 取模。

输入输出样例

输入样例#1:

10 10
18 2 6 17 7 19 17 6 2 12 14
4 15 5 20 2 6 20 12 16 5

输出样例#1:

18147258
804760733
161737928
73381527
23750
973451550
73381527
525589927
842520242
161737928

说明

$n,m \in [1,64000]$,$a_i,[x^i]f(x) \in [0,998244352]$。$[x^i]f(x)$ 表示 $f(x)$ 的 $i$ 次项系数。

解题分析

很妙的做法。(代码也长的一…)

对于 $a_1,a_2,...,a_n$, 我们将其分为两部分, 构造两个多项式 $g(x)=(x-a_1)(x-a_2)...(x-a_{\lfloor \frac{n}{2}\rfloor})$, $h(x)=(x-a_{\lfloor\frac{n}{2}\rfloor+1})...(x-a_n)$。

然后我们发现, 对于 $i\in[1,\lfloor\frac{n}{2}\rfloor]$, 有 $g(a_i)=0$, 因此 $f(a_i)=(f \bmod g)(a_i)$, 所以我们直接将原多项式对 $g(x)$ 取模, 递归处理即可。 同样对右半部分 (对 $h(x)$ 取模) 进行处理。

计算 $g(x)$、$h(x)$ 先用分治 $NTT$ 处理, 存在 $vector$ 中。

注意随时清空辅助的数组!! 否则极可能会莫名其妙 $WA$ !!

总复杂度 $O(N\log^2 N)$。

代码如下:

#include <cstdio>
#include <cstring>
#include <cctype>
#include <cassert>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <vector>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define MX 400500
#define ll long long
#define MOD 998244353
#define g 3
#define Ginv 332748118
// Reads the next run of decimal digits from stdin into x.
//
// Every character that is not a decimal digit is skipped first, so there is
// no sign handling: a leading '-' is treated like any other separator.
// x is always reset to 0 before parsing; if end of input is reached before
// any digit is seen, the function returns with x == 0 instead of waiting for
// more input.
template <class T>
inline void in(T &x)
{
	x = 0;
	// Keep the character in an int: getchar() returns EOF out of band, and
	// isdigit() is only defined for unsigned-char values and EOF.
	int ch = std::getchar();
	while (ch != EOF && !std::isdigit(ch)) ch = std::getchar();
	// isdigit(EOF) is false, so this loop also terminates at end of input.
	for (; std::isdigit(ch); ch = std::getchar())
		x = x * 10 + (ch - '0');
}
IN int fpow(R int base, R int tim)
{
	int ret = 1;
	W (tim)
	{
		if (tim & 1) ret = 1ll * ret * base % MOD;
		base = 1ll * base * base % MOD, tim >>= 1;
	}
	return ret;
}
// n: degree of the input polynomial (main reads n + 1 coefficients).
// m: number of evaluation points (main reads m of them).
int n, m;
// dv[k]: coefficients, lowest degree first, of the product of (x - val[i])
//        over the index range owned by tree node k; filled by Poly::CDQ,
//        where node k has children 2k and 2k + 1.
// A:     coefficients of the input polynomial, lowest degree first, each
//        reduced modulo MOD by main.
std::vector <int> dv[MX], A;
// a, b:   operand buffers handed to Poly::NTT by CDQ, Getinv and Div.
// c, d, e: dividend, divisor and remainder staging buffers used by Poly::solve.
// rev:    index permutation table read by Poly::NTT; the callers refill
//         rev[1..len) before transforming. rev[0] is never written, so it
//         keeps its static zero initialisation.
// val:    evaluation points, stored at indices 1..m.
// GR, DR: reversed dividend and reversed divisor inside Poly::Div.
// FR:     not referenced by any code visible in this file.
// buf:    receives the power-series inverse, then the quotient, in Poly::Div.
int a[MX], b[MX], c[MX], d[MX], e[MX], rev[MX], val[MX], GR[MX], DR[MX], FR[MX], buf[MX];
namespace Poly
{
	IN void NTT(int *dat, R int len, R int typ)
	{
		for (R int i = 1; i < len; ++i) if (rev[i] > i) std::swap(dat[rev[i]], dat[i]);
		R int seg, step, bd, now, cur, buf1, buf2, base, deal;
		for (seg = 1; seg < len; seg <<= 1)
		{
			base = fpow(typ ? g : Ginv, (MOD - 1) / (seg << 1)), step = seg << 1;
			for (now = 0; now < len; now += step)
			{
				bd = now + seg, deal = 1;
				for (cur = now; cur < bd; ++cur, deal = 1ll * deal * base % MOD)
				{
					buf1 = dat[cur], buf2 = 1ll * dat[cur + seg] * deal % MOD;
					dat[cur] = (buf1 + buf2) % MOD, dat[cur + seg] = (buf1 - buf2 + MOD) % MOD;
				}
			}
		}
		if (typ) return; int inv = fpow(len, MOD - 2);
		for (R int i = 0; i < len; ++i) dat[i] = 1ll * dat[i] * inv % MOD;
	}
	void CDQ(R int now, R int lef, R int rig)
	{
		if (lef == rig)
		{
			dv[now].push_back((MOD - val[lef] + MOD) % MOD);
			dv[now].push_back(1);
			return;
		}
		int mid = lef + rig >> 1, ls = now << 1, rs = now << 1 | 1;
		CDQ(ls, lef, mid), CDQ(rs, mid + 1, rig);
		int up = rig - lef + 1, len = 1, lg = 0, lsiz = mid - lef + 1, rsiz = rig - mid;
		for (; len <= up; len <<= 1, lg++);
		for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
		for (R int i = 0; i < len; ++i) a[i] = b[i] = 0;
		for (R int i = 0; i <= lsiz; ++i) a[i] = dv[ls][i];
		for (R int i = 0; i <= rsiz; ++i) b[i] = dv[rs][i];
		NTT(a, len, 1), NTT(b, len, 1);
		for (R int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
		NTT(a, len, 0);
		for (R int i = 0; i <= up; ++i) dv[now].push_back(a[i]);
	}
	void Getinv(R int up, R int lg, int *ind, int *ans)
	{
		if (up == 1) return ans[0] = fpow(ind[0], MOD - 2), void();
		Getinv(up >> 1, lg - 1, ind, ans); R int len = up << 1;
		for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
		for (R int i = 0; i < len; ++i) a[i] = b[i] = 0;
		for (R int i = 0; i < up; ++i) a[i] = ans[i], b[i] = ind[i];
		NTT(a, len, 1), NTT(b, len, 1);
		for (R int i = 0; i < len; ++i) a[i] = (2 * a[i] % MOD - 1ll * b[i] * a[i] % MOD * a[i] % MOD + MOD) % MOD;
		NTT(a, len, 0);
		for (R int i = 0; i < up; ++i) ans[i] = a[i];
	}
	IN void Div(int *G, int *Divs, R int n, R int m, int *sur)
	{
		for (R int i = 0; i <= n; ++i) sur[i] = GR[i] = DR[i] = 0;
		for (R int i = 0; i <= n; ++i) GR[i] = G[n - i];
		for (R int i = 0; i <= m; ++i) DR[i] = Divs[m - i];
		int len = 1, lg = 0, bd = n - m + 1;
		for (; len <= bd; len <<= 1, lg++);
		for (R int i = 0; i < len; ++i) buf[i] = 0;
		Getinv(len, lg + 1, DR, buf); len <<= 1, lg++;
		for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
		for (R int i = 0; i < len; ++i) a[i] = b[i] = 0;
		for (R int i = 0; i < bd; ++i) a[i] = buf[i];
		for (R int i = 0; i < bd; ++i) b[i] = GR[i];
		NTT(a, len, 1), NTT(b, len, 1);
		for (R int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
		NTT(a, len, 0);
		for (len = 1, lg = 0; len <= n; len <<= 1, lg++);
		for (R int i = 1; i < len; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << lg - 1);
		for (R int i = 0; i < bd; ++i) buf[i] = a[bd - i - 1];
		for (R int i = 0; i < len; ++i) a[i] = b[i] = 0;
		for (R int i = 0; i <= m; ++i) a[i] = Divs[i];
		for (R int i = 0; i < bd; ++i) b[i] = buf[i];
		NTT(a, len, 1), NTT(b, len, 1);
		for (R int i = 0; i < len; ++i) a[i] = 1ll * a[i] * b[i] % MOD;
		NTT(a, len, 0);
		for (R int i = 0; i < m; ++i) sur[i] = (G[i] - a[i] + MOD) % MOD;
	}
	IN void solve(R int now, R int lef, R int rig, const std::vector <int> &dat, R int up)
	{
		if (lef == rig)
		{
			for (R int i = 0; i <= up; ++i) c[i] = dat[i];
			for (R int i = 0; i <= 1; ++i) d[i] = dv[now][i];
			Div(c, d, up, 1, e);
			printf("%d\n", e[0]);
			return;
		}
		std::vector <int> l, r; l.clear(), r.clear(); int len = 1, lg = 0;
		int mid = lef + rig >> 1, ls = now << 1, rs = now << 1 | 1;
		int lsiz = mid - lef + 1, rsiz = rig - mid;
		for (R int i = 0; i <= up; ++i) c[i] = dat[i];
		for (R int i = 0; i <= lsiz; ++i) d[i] = dv[ls][i];
		Div(c, d, up, lsiz, e);
		for (R int i = 0; i < lsiz; ++i) l.push_back(e[i]);
		for (R int i = 0; i <= rsiz; ++i) d[i] = dv[rs][i];
		Div(c, d, up, rsiz, e);
		for (R int i = 0; i < rsiz; ++i) r.push_back(e[i]);
		solve(ls, lef, mid, l, lsiz - 1);
		solve(rs, mid + 1, rig, r, rsiz - 1);
	}
}
// Entry point: reads n, m, the n + 1 coefficients (lowest degree first) and
// the m evaluation points from stdin, builds the product tree and prints the
// m results.
int main(void)
{
	in(n);
	in(m);
	for (int i = 0; i <= n; ++i)
	{
		int coef;
		in(coef);
		coef %= MOD;
		A.push_back(coef);
	}
	// Evaluation points are stored 1-based to match the tree's index ranges.
	for (int i = 1; i <= m; ++i) in(val[i]);
	Poly::CDQ(1, 1, m);
	Poly::solve(1, 1, m, A, n);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/LPA20020220/article/details/85240439