P4081 [USACO17DEC]Standing Out from the Herd

思路

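对所有串建立广义后缀自动机（广义SAM），并为每个状态记录 SZ：该状态所表示的子串在多少个不同的输入串中出现过（插入第 i 个串时，沿后缀链接向上给经过的状态 SZ 加一，并用 last 标记防止同一个串被重复计数）。统计本质不同子串时，只统计 SZ=1 的状态即可：这样的状态贡献 maxlen−minlen+1 个子串，它们全部只出现在串 last 中，把它们加到该串的答案上。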

代码

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
// Per-state arrays of the generalized suffix automaton, indexed by state id.
// State 1 is the root; id 0 is used as "no state" (empty transition, end of
// the suffix-link chain).
int maxlen[201000];     // length of the longest substring represented by the state
int minlen[201000];     // length of the shortest substring represented by the state
int trans[201000][26];  // outgoing transitions on 'a'..'z'; 0 means no transition
int suflink[201000];    // suffix link (0 for the root)
int ans[201000];        // per-string answer, indexed by 1-based string number
int sz[201000];         // number of distinct input strings counted on this state
int last[201000];       // most recent string index recorded on this state
int Nodecnt;            // highest allocated state id
int n;                  // number of input strings
char s[201000];         // input buffer; the current string is stored from s[1]
// Allocate the next state id (ids are handed out by pre-incrementing Nodecnt)
// and initialise its fields from the arguments.
// If _trans is non-null, its 26 outgoing transitions are copied into the new
// state (used when cloning). Otherwise the new row of `trans` is left as it
// is, which is all zeros because ids are never reused. Returns the new id.
int New_state(int _maxlen,int _minlen,int *_trans,int _sz,int _last,int _suflink){
    const int id=++Nodecnt;
    maxlen[id]=_maxlen;
    minlen[id]=_minlen;
    sz[id]=_sz;
    last[id]=_last;
    suflink[id]=_suflink;
    if(_trans)
        std::copy(_trans,_trans+26,trans[id]);
    return id;
}
// Record that string x occurs in state u and in the states on u's
// suffix-link chain. Each visited state is stamped with last=x and has its
// sz incremented. The walk stops at id 0 (past the root) or at the first
// state already stamped with x, so states marked earlier for x are not
// incremented again.
void update(int u,int x){
    for(;u!=0;u=suflink[u]){
        if(last[u]==x)
            break;
        last[u]=x;
        ++sz[u];
    }
}
// Append character c to the generalized suffix automaton, continuing from
// state u, on behalf of input string number inq. main passes the 1-based
// string index as inq and c = s[j]-'a'. Returns the state for the prefix read
// so far; main feeds it back in as u for the next character.
//
// Two situations are handled:
//  * trans[u][c] already exists: reuse its target v if it is "solid"
//    (maxlen[v]==maxlen[u]+1). Otherwise split v by cloning it into a state
//    of length maxlen[u]+1 and return the clone.
//  * trans[u][c] is missing: ordinary suffix-automaton extension with a new
//    state z (plus a clone when needed).
// Every return path calls update() on the returned state, so that state and
// its suffix-link ancestors count string inq in sz.
int add_len(int u,int c,int inq){
    if(trans[u][c]){
        int v=trans[u][c];
        // v already represents exactly the extended prefix: reuse it.
        if(maxlen[v]==maxlen[u]+1){
            update(v,inq);
            return v;
        }
        // Split v: clone y takes length maxlen[u]+1 and copies v's transitions,
        // sz, last and suffix link. y's substrings belonged to v until now, so
        // they start with v's per-string counters. minlen is passed as 0 and
        // fixed below.
        int y=New_state(maxlen[u]+1,0,trans[v],sz[v],last[v],suflink[v]);
        suflink[v]=y;
        minlen[v]=maxlen[y]+1;
        // Redirect transitions on u's suffix-link chain that pointed to v.
        while(u&&trans[u][c]==v){
            trans[u][c]=y;
            u=suflink[u];
        }
        minlen[y]=maxlen[suflink[y]]+1;
        update(y,inq);
        return y;
    }
    else{
        // New state for the extended prefix; sz/last start at 0 and are set by
        // the update() call on each return path below.
        int z=New_state(maxlen[u]+1,0,NULL,0,0,0);
        // Add the missing c-transitions along u's suffix-link chain.
        while(u&&trans[u][c]==0){
            trans[u][c]=z;
            u=suflink[u];
        }
        // Walked past the root (id 0): z's suffix link is the root (state 1).
        if(!u){
            suflink[z]=1;
            minlen[z]=1;
            update(z,inq);
            return z; 
        }
        int v=trans[u][c];
        // Solid target: link z directly to v.
        if(maxlen[v]==maxlen[u]+1){
            suflink[z]=v;
            minlen[z]=maxlen[v]+1;
            update(z,inq);
            return z;
        }
        // Non-solid target: clone v into y (same copying as above) and make y
        // the suffix link of both v and z.
        int y=New_state(maxlen[u]+1,0,trans[v],sz[v],last[v],suflink[v]);
        suflink[v]=suflink[z]=y;
        minlen[v]=minlen[z]=maxlen[y]+1;
        while(u&&(trans[u][c]==v)){
            trans[u][c]=y;
            u=suflink[u]; 
        }
        minlen[y]=maxlen[suflink[y]]+1;
        // suflink[z]==y, so this walk also stamps the clone y for string inq.
        update(z,inq);
        return z;
    }
}
// Read n strings and insert each into one shared generalized suffix
// automaton. Then credit every state owned by exactly one string to that
// string.
//
// A state whose substrings occur in exactly one input string (sz==1, with
// that string recorded in last[]) contributes maxlen-minlen+1 distinct
// substrings, all unique to that string. Totals are accumulated in 64 bits:
// one string of length L can have up to L*(L+1)/2 distinct substrings, which
// overflows int for L around 1e5.
int main(){
    Nodecnt=1;                                  // state 1 is the root; suflink[1]==0 ends suffix-link walks
    if(scanf("%d",&n)!=1||n<0)
        return 0;
    std::vector<long long> unique_substrings(n+1,0);   // 1-based per-string totals
    for(int i=1;i<=n;i++){
        if(scanf("%s",s+1)!=1)
            break;                              // truncated input: stop inserting strings
        int cur=1,len=strlen(s+1);              // every string starts again from the root
        for(int j=1;j<=len;j++)
            cur=add_len(cur,s[j]-'a',i);        // assumes characters 'a'..'z' (26 transitions per state)
    }
    for(int i=2;i<=Nodecnt;i++){
        // Each non-root state is stamped by update() when it is created, so
        // sz>=1. sz==1 means its substrings belong only to string last[i].
        if(sz[i]==1)
            unique_substrings[last[i]]+=maxlen[i]-minlen[i]+1;
    }
    for(int i=1;i<=n;i++)
        printf("%lld\n",unique_substrings[i]);
    return 0;
}

猜你喜欢

转载自 https://www.cnblogs.com/dreagonm/p/10721243.html