题目大意:给出 n 个单词,问有多少个长度为 m 的字符串包含 n 个单词中的其中一个
JSOI2007的同名题目的数据范围是 n<=60, m<=100,这样就可以直接在AC自动机上dfs,搜到单词词尾答案加上26的剩余位数次方,否则继续搜下去。
把数据范围改成 n<=10, m<=1000000,且n个单词每个单词长度不超过 6,怎么做?
显然不能dfs了,数据范围限制最多有60个结点,60*60符合矩乘的空间,60*60*60也符合时间,考虑矩阵元素 (x, y) 是由第 x 行和第 y 列相乘,例如长宽为 2 的矩阵中 (1, 2) = (1, 1) * (1, 2) + (1, 2) * (2, 2),相当于一个长度为 3 起点为 1 终点为 2 的串,可以是 1->1->2,也可以 1->2->2。设mat[ i ][ j ]表示从 i 跳到 j 的方案数。
利用容斥原理,答案为总数 - 不出现任何单词的字符串个数,初始化时,若 val[ i ] == val[ j ] == 0 则 ++mat[ i ][ j ]。
for (int i = 0; i <= sz; ++i)
for (int j = 0; j < 26; ++j) if (val[i] == 0 && val[ch[i][j]] == 0) ++a.mat[i][ch[i][j]];
还需注意在 get_fail 的时候,val[ i ]应该 | 上 val[ fail[ i ] ]。
#include <cstdio> #include <cstring> #include <queue> const int P = 10007; char s[16]; int ch[100][26], sz, val[100], fail[100]; struct Mat { int mat[100][100]; void mult(Mat a, Mat b) { for (int i = 0; i <= sz; ++i) for (int j = 0; j <= sz; ++j) { mat[i][j] = 0; for (int k = 0; k <= sz; ++k) { mat[i][j] += a.mat[i][k] * b.mat[k][j]; if (mat[i][j] >= P) mat[i][j] %= P; } } } } a, unit; void insert(char *s) { int cur = 0, n = strlen(s); for (int i = 0; i < n; ++i) { int c = s[i] - 'A'; if (!ch[cur][c]) ch[cur][c] = ++sz; cur = ch[cur][c]; } val[cur] = 1; } void get_fail() { std::queue<int> Q; for (int c = 0; c < 26; ++c) if (ch[0][c]) Q.push(ch[0][c]); while (!Q.empty()) { int cur = Q.front(); Q.pop(); for (int c = 0; c < 26; ++c) { if (!ch[cur][c]) ch[cur][c] = ch[fail[cur]][c]; else { fail[ch[cur][c]] = ch[fail[cur]][c]; val[cur] |= val[fail[cur]]; Q.push(ch[cur][c]); } } } } Mat fast_pow(Mat a, int p) { Mat res = unit; for (;p; p >>= 1, a.mult(a, a)) { if (p & 1) res.mult(res, a); } return res; } int fast_pow(int a, int p) { int res = 1; for (;p; p >>= 1, a *= a) { if (a >= P) a %= P; if (p & 1) { res *= a; if (res >= P) res %= P; } } return res; } int main() { int n, m, ans = 0; scanf("%d%d", &n, &m); for (int i = 1; i <= n; ++i) scanf("%s", s), insert(s); get_fail(); for (int i = 0; i <= sz; ++i) { unit.mat[i][i] = 1; for (int j = 0; j < 26; ++j) if (val[i] == 0 && val[ch[i][j]] == 0) ++a.mat[i][ch[i][j]]; } a = fast_pow(a, m); for (int i = 0; i <= sz; ++i) { ans += a.mat[0][i]; if (ans >= P) ans %= P; } printf("%d\n", (fast_pow(26, m) - ans + P) % P); return 0; }