Codeforces 710F String Set Queries AC自动机+二进制分组

题意

有一个字符串集合D,要求资辞m个操作:
1 str表示在D中加入字符串str
2 str表示在D中删除字符串str
3 str表示询问D中每个字符串在str中的出现次数的和
强制在线
m , | s t r | 3 10 5

分析

一开始想到一个在线建广义sam+lct的做法,后来羊告诉我说在线建广义sam的复杂度是不对的,加上这个做法十分难打,就放弃了。
正解是AC自动机+二进制分组。
首先我们可以分别维护插入串和删除串,用插入的串的答案-删除的串的答案即可。
如果可以离线的话,显然我们可以用AC自动机来做,但由于fail链只能每次暴力重构,所以不能直接做。
考虑用二进制分组,也就是对新加入的一个串建AC自动机,然后如果最后两个AC自动机的size相同,就暴力合并然后重构fail指针。
查询就在每个AC自动机上分别查询即可。
显然每个时刻AC自动机的数量不超过 O ( l o g n ) 个,且每个字符串最多被合并 O ( l o g n ) 次,所以总的复杂度就是 O ( n l o g n ) ,其中 表示字符集大小。

代码

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#include<queue>

typedef long long LL; 

const int N=300005;

int m,len;
char str[N];
std::queue<int> que;

struct ACAM
{
    int sz,cnt[N],ch[N][26],ac[N][26],rt[25],sum[N],fail[N],t[25];

    void add(int &d,int x)
    {
        if (!d) d=++sz;
        if (x>len) {cnt[d]++;return;}
        add(ch[d][str[x]-'a'],x+1);
    }

    int merge(int x,int y)
    {
        if (!x||!y) return x^y;
        cnt[x]+=cnt[y];
        for (int i=0;i<26;i++) ch[x][i]=merge(ch[x][i],ch[y][i]);
        return x;
    }

    void build(int d)
    {
        que.push(d);fail[d]=d;sum[d]=0;
        for (int i=0;i<26;i++) ac[d][i]=d;
        while (!que.empty())
        {
            int u=que.front();que.pop();
            sum[u]=sum[fail[u]]+cnt[u];
            for (int i=0;i<26;i++)
                if (ch[u][i])
                {
                    int v=ch[u][i];
                    que.push(v);
                    fail[v]=ac[fail[u]][i]; 
                    ac[u][i]=v;
                }
                else ac[u][i]=ac[fail[u]][i];
        }
    }

    void ins()
    {
        add(rt[0],1);
        int i;
        for (i=0;t[i]==1;i++)
        {
            rt[i+1]=merge(rt[i],rt[i+1]);
            t[i]^=1;
            rt[i]=0;
        }
        t[i]^=1;
        build(rt[i]);
    }

    LL calc()
    {
        LL ans=0;
        for (int i=0;i<=20;i++)
        {
            if (!rt[i]) continue;
            int x=rt[i];
            for (int j=1;j<=len;j++)
            {
                x=ac[x][str[j]-'a'];
                ans+=sum[x];
            }
        }
        return ans;
    }
}ac1,ac2;

int main()
{
    scanf("%d",&m);
    while (m--)
    {
        int op;scanf("%d%s",&op,str+1);
        len=strlen(str+1);
        if (op==1) ac1.ins();
        else if (op==2) ac2.ins();
        else printf("%lld\n",ac1.calc()-ac2.calc()),fflush(stdout);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_33229466/article/details/80931380