bzoj 3992 [SDOI2015]序列统计 NTT

题面

题目传送门

解法

  • 先考虑一个最简单的dp:令 f [ i ] [ j ] f[i][j] 表示前 i i 个数的乘积对 m m 取模为 j j 的方案数,转移比较简单,在这里就不写了。
  • 但是我们会发现,转移的时候是乘法,并没有特别好的优化方式。
  • 注意 m m 是一个质数,一定存在原根 g g 。那么,我们就可以用 g g 的若干次方表示出 [ 1 , m ) [1,m) 中的所有数。
  • 现在我们不妨对原来的状态稍作修改, f [ i ] [ j ] f[i][j] 表示为前 i i 个数的乘积对 m m 取模与 g j g^j 同余,然后转移就是 f [ i ] [ j ] = k f [ i 1 ] [ j k ] × s u m [ k ] f[i][j]=\sum_{k}f[i-1][j-k]\times sum[k]
  • 假设当前的 f [ i ] f[i] a a f [ i 1 ] f[i-1] b b ,那么 a = b s u m a=b*sum ,所以最后的答案 f [ n ] = f [ 0 ] s u m n f[n]=f[0]*sum^n
  • 因为卷积满足结合律,所以我们可以对 s u m n sum^n 进行快速幂,在做乘法的用NTT加速。
  • 时间复杂度: O ( m log n log m ) O(m\log n\log m)

【注意事项】

  • 原根可以暴力找出,并不会存在很大的原根。
  • 在乘法的时候需要注意,中途的结果长度可能大于 m m ,那么对于 i [ m 1 , 2 m ) i\in [m-1,2m) a [ i m + 1 ] + = a [ i ] a[i-m+1]+=a[i] ,因为 φ ( m ) = m 1 \varphi(m)=m-1 g g 为原根,所以最小循环节为 m 1 m-1 ,那么答案也要相应地加上去。

代码

#include <bits/stdc++.h>
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
template <typename T> void read(T &x) {
	x = 0; int f = 1; char c = getchar();
	while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
	while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
const int N = 20010, Mod = 1004535809;
int m, a[N], b[N], c[N], f[N], g[N], num[N], rev[N];
bool check(int n, int g) {
	for (int i = 1, cur = g; i < n - 1; i++, cur = cur * g % n)
		if (cur == 1) return false;
	return true;
}
int calc(int n) {for (int i = 2; ; i++) if (check(n, i)) return i;}
int Pow(int x, int y) {
	int ret = 1;
	for (; y; y >>= 1, x = 1ll * x * x % Mod)
		if (y & 1) ret = 1ll * ret * x % Mod;
	return ret;
}
void getrev(int l) {
	for (int i = 0; i < (1 << l); i++)
		rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
}
void NTT(int *a, int n, int fl) {
	for (int i = 0; i < n; i++)
		if (rev[i] < i) swap(a[i], a[rev[i]]);
	for (int i = 1; i < n; i <<= 1) {
		int wn = Pow(3, fl == 1 ? (Mod - 1) / (i << 1) : Mod - 1 - (Mod - 1) / (i << 1));
		for (int j = 0, r = i << 1; j < n; j += r) {
			int w = 1;
			for (int k = 0; k < i; k++, w = 1ll * w * wn % Mod) {
				int tx = a[j + k], ty = 1ll * w * a[i + j + k] % Mod;
				a[j + k] = (tx + ty) % Mod, a[i + j + k] = (tx - ty + Mod) % Mod;
			}
		}
	}
	if (fl == -1) {
		int tmp = Pow(n, Mod - 2);
		for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * tmp % Mod;
	}
}
void mul(int *a, int *tx, int *ty, int n) {
	for (int i = 0; i < n; i++) b[i] = tx[i], c[i] = ty[i];
	NTT(b, n, 1), NTT(c, n, 1);
	for (int i = 0; i < n; i++) a[i] = 1ll * b[i] * c[i] % Mod;
	NTT(a, n, -1);
	for (int i = m - 1; i < n; i++) a[i - m + 1] = (a[i - m + 1] + a[i]) % Mod, a[i] = 0;
}
int calc(int n, int len, int x) {
	g[0] = 1;
	while (n) {
		if (n & 1) mul(g, g, f, len);
		n >>= 1, mul(f, f, f, len);
	}
	return g[num[x]];
}
int main() {
	int n, tx, s;
	read(n), read(m), read(tx), read(s);
	int t = calc(m);
	for (int i = 1, cur = t; i < m - 1; i++, cur = cur * t % m) num[cur] = i;
	for (int i = 1; i <= s; i++) {
		int x; read(x);
		if (x) f[num[x]]++;
	}
	int l = 0, len = 1;
	while (len <= 2 * m) l++, len <<= 1; getrev(l);
	cout << calc(n, len, tx) << "\n";
	return 0;
}

猜你喜欢

转载自blog.csdn.net/emmmmmmmmm/article/details/86762344
今日推荐