CF1202E You Are Given Some Strings...(AC自动机)

看到多串匹配应该是AC自动机。

考虑所求式子,处理出所有$s_i$在$t$中的结尾位置。然后即查询有多少在那个位置后的$s_j$。

即统计位置$i$是多少串的结尾,位置$i+1$是多少串的开头。

思路很清晰了,只要正着反着跑一遍AC自动机应该就行了。

#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
char t[N], s[N];
int ch1[N][26], fail1[N], cnt1[N], tot1 = 1;
int ch2[N][26], fail2[N], cnt2[N], tot2 = 1;
queue<int> q;
int n;
long long l[N], r[N];
long long ans;
int main() {
    scanf("%s", t + 1);
    scanf("%d", &n);
    for (int w = 1; w <= n; w++) {
        scanf("%s", s + 1);
        int p = 1;
        int len = strlen(s + 1);
        for (int i = 1; i <= len; i++) {
            if (!ch1[p][s[i] - 'a']) ch1[p][s[i] - 'a'] = ++tot1;
            p = ch1[p][s[i] - 'a'];
        }
        cnt1[p]++;
        p = 1;
        for (int i = len; i >= 1; i--) {
            if (!ch2[p][s[i] - 'a']) ch2[p][s[i] - 'a'] = ++tot2;
            p = ch2[p][s[i] - 'a'];
        }
        cnt2[p]++;
    }
    for (int i = 0; i < 26; i++) ch1[0][i] = 1;
    fail1[1] = 0;
    while (!q.empty()) q.pop();
    q.push(1);
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        cnt1[x] += cnt1[fail1[x]];
        for (int i = 0; i < 26; i++) {
            if (ch1[x][i]) {
                fail1[ch1[x][i]] = ch1[fail1[x]][i];
                q.push(ch1[x][i]);
            } else {
                ch1[x][i] = ch1[fail1[x]][i];
            }
        }
    }
    for (int i = 0; i < 26; i++) ch2[0][i] = 1;
    fail2[1] = 0;
    while (!q.empty()) q.pop();
    q.push(1);
    while (!q.empty()) {
        int x = q.front();
        q.pop();
        cnt2[x] += cnt2[fail2[x]];
        for (int i = 0; i < 26; i++) {
            if (ch2[x][i]) {
                fail2[ch2[x][i]] = ch2[fail2[x]][i];
                q.push(ch2[x][i]);
            } else {
                ch2[x][i] = ch2[fail2[x]][i];
            }
        }
    }
    int len = strlen(t + 1);
    int p = 1;
    for (int i = 1; i <= len; i++) {
        p = ch1[p][t[i] - 'a'];
        l[i] = cnt1[p];
    }
    p = 1;
    for (int i = len; i >= 1; i--) {
        p = ch2[p][t[i] - 'a'];
        r[i] = cnt2[p];
    }
    for (int i = 1; i < len; i++) {
        ans += l[i] * r[i + 1];
    }
    printf("%lld", ans);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/zcr-blog/p/13166381.html