回文自动机/回文树(模板)

大犇文章
大犇文章
洛谷3649

还是封装了比较香

#include <iostream>
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <stack>
#include <queue>
#include <map>
#include <vector>

using namespace std;    

typedef unsigned long long ull;
typedef long long ll;

const int MAXN = 300000 + 5;
const int Mod = 1e9 + 7;
const int N = 26;

struct Palindromic_Tree
{
    
    
    int nxt[MAXN][N] , fail[MAXN], len[MAXN];
    int num[MAXN] , cnt[MAXN], s[MAXN];
    //cnt[i]表示节点i表示的本质不同的串的个数(建树时求出的不是完全的,最后count()函数跑一遍以后才是正确的)
    //num[i]表示以节点i表示的最长回文串的最右端点为回文串结尾的回文串个数。
    int last , n ,p;
    int newnode(int l)
    {
    
    
        for(int i = 0; i < N;  i++) nxt[p][i] = 0;
        cnt[p] = num[p]  = 0;
        len[p] = l;
        return p++;   
    }
    void init()
    {
    
    
        p = 0;
        newnode(0);
        newnode(-1);
        last = n = 0;
        s[n] = -1;
        fail[0] = 1;
    }
    int getfail(int x)
    {
    
    
        while(s[n - len[x] - 1] != s[n])    x = fail[x];
        return x;
    }
    void insert(int c)
    {
    
    
        c -= 'a';
        s[++n] = c;
        int cur = getfail(last);
        if(!nxt[cur][c]){
    
    
            int now = newnode(len[cur] + 2);
            fail[now] = nxt[ getfail( fail[cur] ) ][c];
            nxt[cur][c] = now;//记录这个节点标号
            num[now] = num[fail[now]] + 1;
        }
        last = nxt[cur][c];
        //如果是已经存在的回文串,就回到表示这个回文串的节点并++,也就是这个
        //本质不同回文串的出现次数+1(但并不是最终次数)
        cnt[last]++;
    }
    void Count()
    {
    
    
        for(int i = p - 1; i >= 0; i--) cnt[fail[i]] += cnt[i];
        //父亲累加儿子的cnt,因为如果fail[v]=u,则u一定是v的子回文串! 
        //而且不会重复计算,因为u是两端加上c等于v,而fail[v]是v的最长回文后缀,
        //出现位置不一样,但是在回文树中都是由用一个标记表示,所以直接累加就可以
        //得到标记   i   所表示的回文串出现的次数
    }
}pam;

char str[MAXN];

int main()
{
    
    
    pam.init();
    scanf("%s", str + 1);
    int len = strlen(str + 1);
    for(int i = 1; i <= len; i++)   pam.insert(str[i]);
    pam.Count();
    ll ans = 0;
    for(int i = 2; i <= pam.p - 1; i++) ans = max(ans, 1ll * pam.len[i]* pam.cnt[i]);
    printf("%lld\n", ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/CUCUC1/article/details/108890148
今日推荐