ac自动机暴力跳fail匹配——hdu5880

很简单的题,ac自动机里再维护一个len表示每个状态的串长,用s去query时每到一个结点都要暴力跳fail,因为有可能这个结点不是,但是其fail是危险结点,找到一个就直接break

再用个差分数组快速统计覆盖情况即可

using namespace std;
#define N 1000005

char s[N],t[N];
int n,cnt[N];

struct Trie{
    int nxt[N][26],fail[N],end[N],Len[N];
    int root,L;
    int newnode(){
        memset(nxt[L],-1,sizeof nxt[L]);
        end[L]=0;
        return L++;
    }
    void init(){
        L=0;
        root=newnode();
    }
    void insert(char buf[]){
        int len=strlen(buf);
        int now=root;
        for(int i=0;i<len;i++){
            if(nxt[now][buf[i]-'a']==-1)
                nxt[now][buf[i]-'a']=newnode();
            now=nxt[now][buf[i]-'a'];
        }
        end[now]++;Len[now]=len;
    }
    void build(){
        queue<int>q;
        fail[root]=root;
        for(int i=0;i<26;i++)
            if(nxt[root][i]==-1)
                nxt[root][i]=root;
            else {
                fail[nxt[root][i]]=root;
                q.push(nxt[root][i]);
            }
        while(q.size()){
            int now=q.front();
            q.pop(); 
            for(int i=0;i<26;i++)
                if(nxt[now][i]==-1)
                    nxt[now][i]=nxt[fail[now]][i];
                else {
                    fail[nxt[now][i]]=nxt[fail[now]][i];
                    q.push(nxt[now][i]);
                }
        }
    }
    
    void query(char *s){
        int now=root;
        int len=strlen(s);
        for(int i=0;i<len;i++){
            if(s[i]<'a' || s[i]>'z'){
                now=root;continue;
            }
            now=nxt[now][s[i]-'a'];
            int p=now;
            while(p){
                if(end[p]){//遇到危险结点了 
                    cnt[i+1]--;
                    cnt[i-Len[p]+1]++;
                    break;
                }
                p=fail[p];
            }
        }
    }
}ac;

int main(){
    int tt;cin>>tt;while(tt--){
        ac.init();
        cin>>n;
        for(int i=1;i<=n;i++){
            scanf("%s",s);
            ac.insert(s);
        }
        ac.build();
        
        char ch;
        int len=0;
        getchar();
        scanf("%[^\n]%*c",s);
        len=strlen(s);
        
        for(int i=0;i<len;i++){
            t[i]=s[i];
            if(s[i]>='A' && s[i]<='Z')
                s[i]+='a'-'A';
        }
        t[len]=0;
        
        for(int i=0;i<=len;i++)cnt[i]=0;
        ac.query(s);
        for(int i=1;i<len;i++)cnt[i]+=cnt[i-1];
        for(int i=0;i<len;i++){
            if(cnt[i]>=1)printf("*");
            else printf("%c",t[i]);
        }
        
        puts("");
    }
}

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/11706137.html