回文树上dfs——牛客多校第六场C

/*
set里的一定是本质不同的回文串,所以先建立回文树
当a可以通过nxt指针到达b,或者b可以通过fail指针到达a时,a就是b的子串
对于回文树里的每个结点u,我们可以将和其有关的结点为两部分:
    1.结点下方的子树,这部分的所有结点都可以由u在两边加点得到,设大小为 size[u] 
    2.结点向上的fail链,这条链上的所有结点都是u的回文后缀,设大小为 tot[u]
那么所有fail链上的点都是u的子树的子串,所以u的贡献为size[u]*tot[u] 
然后还要去重:对于u的子孙v, v的fail链可能会与u重合,重合部分的贡献在u处已经算过,那么显然v处就不用再算一次
所以dfs时要用vis标记被访问过的fail点,推出递归前回溯即可 
*/
#include<bits/stdc++.h>
using namespace std;
#define maxn 100005
struct PAM{
    int nxt[maxn][26],len[maxn],fail[maxn];
    int num[maxn],cnt[maxn];
    int S[maxn],n,p,last;
    int newnode(int l){
        memset(nxt[p],0,sizeof nxt[p]);
        len[p]=l;
        num[p]=cnt[p]=0;
        return p++;
    }
    void init(){
        p=0;
        newnode(0);
        newnode(-1);
        fail[0]=1;
        last=n=0;
        S[0]=-1;
    }
    int get_fail(int x){
        while(S[n-len[x]-1]!=S[n])x=fail[x];
        return x;
    }
    void add(int c){
        c-='a';S[++n]=c;
        int cur=get_fail(last);
        if(!nxt[cur][c]){
            int now=newnode(len[cur]+2);
            fail[now]=nxt[get_fail(fail[cur])][c];
            nxt[cur][c]=now;
            num[now]=num[fail[now]]+1;
        }
        last=nxt[cur][c];
        cnt[last]++;
    }
    int vis[maxn],size[maxn],tot[maxn];
    void dfs1(int u){
        size[u]=1;
        for(int i=0;i<26;i++)
            if(nxt[u][i]){
                int v=nxt[u][i];
                dfs1(v);
                size[u]+=size[v];
            }
    }
    void dfs2(int u){
        tot[u]=0;
        for(int x=u;!vis[x] && x>1;x=fail[x])
            tot[u]++,vis[x]=u;
        for(int i=0;i<26;i++)
            if(nxt[u][i]){
                int v=nxt[u][i];
                dfs2(v);
            }
        for(int x=u;vis[x]==u&&x>1;x=fail[x])    
            vis[x]=0;
    }
    long long count(){
        for(int i=p-1;i>=2;i--)
            cnt[fail[i]]+=cnt[i];
        dfs1(0);dfs2(0);
        dfs1(1);dfs2(1);
        long long res=0;
        for(int i=2;i<p;i++)
            res=res+size[i]*tot[i];
        return res-(p-2);    
    }
}tr; 
char s[maxn];

int main(){
    int t;cin>>t;
    for(int tt=1;tt<=t;tt++){
        scanf("%s",s);
        int len=strlen(s);
        tr.init();
        for(int i=0;i<len;i++)
            tr.add(s[i]);
        printf("Case #%d: %lld\n",tt,tr.count());
    } 
} 

猜你喜欢

转载自www.cnblogs.com/zsben991126/p/11329261.html