Codeforces 1202 E You Are Given Some Strings... —— AC自动机

This way

题意:

给你一个模式串,和一些匹配串,问你任意两个匹配串连起来在模式串出现的次数中的总和是多少。

题解:

用一个正的ac自动机处理出模式串中每个位置匹配串的结尾数量,再用一个反的自动机处理出模式串中每个位置开头的数量即可。
这里好像要预处理ed数组,在build的时候用前缀和的思想将它加上去。要不然会T

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N=2e5+5,R=26;
struct Tire{
    int nxt[N][R],fail[N],ed[N],num[N];
    int rt,tot,cnt;
    int newnode(){
        for(int i=0;i<R;i++)nxt[tot][i]=-1;
        ed[tot]=0;
        return tot++;
    }
    void init(){
        memset(num,0,sizeof(num));
        tot=cnt=0;
        rt=newnode();
    }
    int insert(char *s){
        int now=rt,len=strlen(s);
        for(int i=0;i<len;i++){
            int val=s[i]-'a';
            if(nxt[now][val]==-1)nxt[now][val]=newnode();
            now=nxt[now][val];
        }
        ed[now]++;
        return now;
    }
    void build(){
        queue<int>q;
        fail[rt]=rt;
        for(int i=0;i<R;i++){
            if(nxt[rt][i]==-1)nxt[rt][i]=rt;
            else {
                fail[nxt[rt][i]]=rt;
                q.push(nxt[rt][i]);
            }
        }
        while(!q.empty()){

            int now=q.front();q.pop();
            ed[now]+=ed[fail[now]];
            for(int i=0;i<R;i++){
                if(nxt[now][i]==-1)nxt[now][i]=nxt[fail[now]][i];
                else {
                    fail[nxt[now][i]]=nxt[fail[now]][i];
                    q.push(nxt[now][i]);
                }
            }
        }
    }
    void cal(char *s)
    {
        int len=strlen(s),now=rt;
        for(int i=0;i<len;i++)
        {
            now=nxt[now][s[i]-'a'];
            num[i]=ed[now];
        }
    }
}ac1,ac2;
char s[N],ss[N];
int main()
{
    int n;
    scanf("%s",s);
    scanf("%d",&n);
    ac1.init(),ac2.init();
    for(int i=1;i<=n;i++)
    {
        scanf("%s",ss);
        ac1.insert(ss);
        reverse(ss,ss+strlen(ss));
        ac2.insert(ss);
    }
    ac1.build(),ac2.build();
    ac1.cal(s);
    reverse(s,s+strlen(s));
    ac2.cal(s);
    ll ans=0;
    int len=strlen(s);
    for(int i=0;i<len;i++)
        ans=ans+1ll*ac1.num[i]*ac2.num[len-i-1-1];
    printf("%lld\n",ans);
    return 0;
}

发布了530 篇原创文章 · 获赞 31 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/tianyizhicheng/article/details/102365983