题面
解法
- 先考虑一个最简单的dp:令 表示前 个数的乘积对 取模为 的方案数,转移比较简单,在这里就不写了。
- 但是我们会发现,转移的时候是乘法,并没有特别好的优化方式。
- 注意 是一个质数,一定存在原根 。那么,我们就可以用 的若干次方表示出 中的所有数。
- 现在我们不妨对原来的状态稍作修改, 表示为前 个数的乘积对 取模与 同余,然后转移就是
- 假设当前的 为 , 为 ,那么 ,所以最后的答案 。
- 因为卷积满足结合律,所以我们可以对 进行快速幂,在做乘法的用NTT加速。
- 时间复杂度:
【注意事项】
- 原根可以暴力找出,并不会存在很大的原根。
- 在乘法的时候需要注意,中途的结果长度可能大于 ,那么对于 , ,因为 , 为原根,所以最小循环节为 ,那么答案也要相应地加上去。
代码
#include <bits/stdc++.h>
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void chkmin(T &x, T y) {x = x > y ? y : x;}
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
const int N = 20010, Mod = 1004535809;
int m, a[N], b[N], c[N], f[N], g[N], num[N], rev[N];
bool check(int n, int g) {
for (int i = 1, cur = g; i < n - 1; i++, cur = cur * g % n)
if (cur == 1) return false;
return true;
}
int calc(int n) {for (int i = 2; ; i++) if (check(n, i)) return i;}
int Pow(int x, int y) {
int ret = 1;
for (; y; y >>= 1, x = 1ll * x * x % Mod)
if (y & 1) ret = 1ll * ret * x % Mod;
return ret;
}
void getrev(int l) {
for (int i = 0; i < (1 << l); i++)
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << l - 1);
}
void NTT(int *a, int n, int fl) {
for (int i = 0; i < n; i++)
if (rev[i] < i) swap(a[i], a[rev[i]]);
for (int i = 1; i < n; i <<= 1) {
int wn = Pow(3, fl == 1 ? (Mod - 1) / (i << 1) : Mod - 1 - (Mod - 1) / (i << 1));
for (int j = 0, r = i << 1; j < n; j += r) {
int w = 1;
for (int k = 0; k < i; k++, w = 1ll * w * wn % Mod) {
int tx = a[j + k], ty = 1ll * w * a[i + j + k] % Mod;
a[j + k] = (tx + ty) % Mod, a[i + j + k] = (tx - ty + Mod) % Mod;
}
}
}
if (fl == -1) {
int tmp = Pow(n, Mod - 2);
for (int i = 0; i < n; i++) a[i] = 1ll * a[i] * tmp % Mod;
}
}
void mul(int *a, int *tx, int *ty, int n) {
for (int i = 0; i < n; i++) b[i] = tx[i], c[i] = ty[i];
NTT(b, n, 1), NTT(c, n, 1);
for (int i = 0; i < n; i++) a[i] = 1ll * b[i] * c[i] % Mod;
NTT(a, n, -1);
for (int i = m - 1; i < n; i++) a[i - m + 1] = (a[i - m + 1] + a[i]) % Mod, a[i] = 0;
}
int calc(int n, int len, int x) {
g[0] = 1;
while (n) {
if (n & 1) mul(g, g, f, len);
n >>= 1, mul(f, f, f, len);
}
return g[num[x]];
}
int main() {
int n, tx, s;
read(n), read(m), read(tx), read(s);
int t = calc(m);
for (int i = 1, cur = t; i < m - 1; i++, cur = cur * t % m) num[cur] = i;
for (int i = 1; i <= s; i++) {
int x; read(x);
if (x) f[num[x]]++;
}
int l = 0, len = 1;
while (len <= 2 * m) l++, len <<= 1; getrev(l);
cout << calc(n, len, tx) << "\n";
return 0;
}