TJOI2013 单词

题目链接:戳我

AC自动机qwqwqwq

给定很多个模式串。问模式串最多在这些模式串的集合中出现多少次?

这道题读题有三个注意点:一是模式串只能在模式串中匹配,最多出现多少次是在所有模式串中出现次数的总和。二是这些只要出现的位置集合不一样即视作又出现了一次。三是给出的模式串会有重复,并不是两两不同的。

第二点是常识,不必多说。第一点如何处理?我们把模式串拼接起来,拼成文本串。但是由于不能跨原先不同的模式串匹配,所以我们要在连接处添加一个永远不会出现的字符。(而且考虑到数组下标问题,这个东西在ASCII里面应该比z大)

对于第三点,我们或许能想起来曾经做过luogu上的AC自动机【模板2】,和这个比较类似。但是那个和这个又有所不同,这个模式串是会有重复的。所以统计方法不能相同,我们应该新开一个数组same,来记录当前这个串和前面哪个串相同(当然,如果是第一次出现就记录自己即可)。我们依然考虑ac_query的时候暴力向上跳fail去累加答案,但是字符串长度大于1e6,暴力跳fail很有可能会超时。所以我们考虑一个小小的优化,把每一次都跳fail的过程去掉,用一个前缀和数组一样的东西把答案累加起来。(这里的前缀和指的是,当前节点为根,所有指向它的fail边的反向边 指向的节点 所记录的值 的和)

然后因为get_fail的时候访问节点是按照dfs序来的,然后一个节点的fail边连向的点一定比它先访问到。所以到最后累加答案的时候我们直接从后往前即可。

具体情况看代码吧qwqwqwqwq

代码如下:

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
#define MAXN 2000010
using namespace std;
int n,cnt,l,tot;
int same[210],dfn[MAXN],ans[210];
char s[MAXN],p[MAXN];
struct Node{int t[26],fail,id,sum;}ac[MAXN];
inline void build(char x[],int k)
{
    int len=strlen(x+1),now=0;
    for(int i=1;i<=len;i++)
    {
        if(ac[now].t[x[i]-'a']==0)
            ac[now].t[x[i]-'a']=++cnt;
        now=ac[now].t[x[i]-'a'];
    }
    if(ac[now].id==0) ac[now].id=k,same[k]=k;
    else same[k]=ac[now].id;
}
inline void get_fail()
{
    queue<int>q;
    for(int i=0;i<=25;i++)
        if(ac[0].t[i]!=0)
            ac[ac[0].t[i]].fail=0,q.push(ac[0].t[i]),dfn[++tot]=ac[0].t[i];
    while(!q.empty())
    {
        int u=q.front();q.pop();
        for(int i=0;i<=25;i++)
        {
            if(ac[u].t[i]!=0) 
            {
                ac[ac[u].t[i]].fail=ac[ac[u].fail].t[i];
                q.push(ac[u].t[i]);
                dfn[++tot]=ac[u].t[i];
            }
            else ac[u].t[i]=ac[ac[u].fail].t[i];
        }
    }
}
inline void ac_query(char x[])
{
    int len=strlen(x+1),now=0;
    for(int i=1;i<=len;i++)
    {
        if(x[i]=='~') now=0;
        else now=ac[now].t[x[i]-'a'];
        ac[now].sum++;
    }
}
int main()
{
    #ifndef ONLINE_JUDGE
    freopen("ce.in","r",stdin);
    #endif
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
    {
        scanf("%s",s+1);
        build(s,i);
        for(int j=1,limit=strlen(s+1);j<=limit;j++)
            p[++l]=s[j];
        p[++l]='~';
    }
    get_fail();
    ac_query(p);
    for(int i=tot;i>=1;i--)
    {
        int now=dfn[i];//now表示当前处理点
        ans[ac[now].id]+=ac[now].sum;
        ac[ac[now].fail].sum+=ac[now].sum;
    }
    for(int i=1;i<=n;i++) printf("%d\n",ans[same[i]]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/fengxunling/p/10393831.html