COGS 1468 文本生成器

题目大意:给出 n 个单词,问有多少个长度为 m 的字符串包含 n 个单词中的其中一个

  JSOI2007的同名题目的数据范围是 n<=60, m<=100,这样就可以直接在AC自动机上dfs,搜到单词词尾答案加上26的剩余位数次方,否则继续搜下去。

  把数据范围改成 n<=10, m<=1000000,且n个单词每个单词长度不超过 6,怎么做?

  显然不能dfs了,数据范围限制最多有60个结点,60*60符合矩乘的空间,60*60*60也符合时间,考虑矩阵元素 (x, y) 是由第 x 行和第 y 列相乘,例如长宽为 2 的矩阵中 (1, 2) = (1, 1) * (1, 2) + (1, 2) * (2, 2),相当于一个长度为 3 起点为 1 终点为 2 的串,可以是 1->1->2,也可以 1->2->2。设mat[ i ][ j ]表示从 i 跳到 j 的方案数。

  利用容斥原理答案为总数 - 不出现任何单词的字符串个数,初始化时,若 val[ i ] == val[ j ] == 0 则 ++mat[ i ][ j ]。

for (int i = 0; i <= sz; ++i)
  
for (int j = 0; j < 26; ++j) if (val[i] == 0 && val[ch[i][j]] == 0) ++a.mat[i][ch[i][j]];

   还需注意在 get_fail 的时候,val[ i ]应该 | 上 val[ fail[ i ] ]。

#include <cstdio>
#include <cstring>
#include <queue>

const int P = 10007;
char s[16];
int ch[100][26], sz, val[100], fail[100];

struct Mat {
    int mat[100][100];
    void mult(Mat a, Mat b) {
        for (int i = 0; i <= sz; ++i)
            for (int j = 0; j <= sz; ++j) {
                mat[i][j] = 0;
                for (int k = 0; k <= sz; ++k) {
                    mat[i][j] += a.mat[i][k] * b.mat[k][j];
                    if (mat[i][j] >= P) mat[i][j] %= P;
                }
            }
    }
} a, unit;

void insert(char *s) {
    int cur = 0, n = strlen(s);
    for (int i = 0; i < n; ++i) {
        int c = s[i] - 'A';
        if (!ch[cur][c]) ch[cur][c] = ++sz;
        cur = ch[cur][c];
    }
    val[cur] = 1;
}
void get_fail() {
    std::queue<int> Q;
    for (int c = 0; c < 26; ++c)
        if (ch[0][c]) Q.push(ch[0][c]);
    while (!Q.empty()) {
        int cur = Q.front(); Q.pop();
        for (int c = 0; c < 26; ++c) {
            if (!ch[cur][c]) ch[cur][c] = ch[fail[cur]][c];
            else {
                fail[ch[cur][c]] = ch[fail[cur]][c];
                val[cur] |= val[fail[cur]];
                Q.push(ch[cur][c]);
            }
        }
    }
}
Mat fast_pow(Mat a, int p) {
    Mat res = unit;
    for (;p; p >>= 1, a.mult(a, a)) {
        if (p & 1) res.mult(res, a);
    }
    return res;
}
int fast_pow(int a, int p) {
    int res = 1;
    for (;p; p >>= 1, a *= a) {
        if (a >= P) a %= P;
        if (p & 1) {
            res *= a;
            if (res >= P) res %= P;
        }
    }
    return res;
}

int main() {
    int n, m, ans = 0;
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%s", s), insert(s);
    get_fail();
    for (int i = 0; i <= sz; ++i) {
        unit.mat[i][i] = 1;
        for (int j = 0; j < 26; ++j)
            if (val[i] == 0 && val[ch[i][j]] == 0) ++a.mat[i][ch[i][j]];
    }
    a = fast_pow(a, m);
    for (int i = 0; i <= sz; ++i) {
        ans += a.mat[0][i];
        if (ans >= P) ans %= P;
    }
    printf("%d\n", (fast_pow(26, m) - ans + P) % P);
    return 0; 
}

猜你喜欢

转载自www.cnblogs.com/milky-w/p/9054607.html