poj3376(manachar+字典树)

题解:两个字符串要拼凑成一个回文串

a串长度小于b串长度时候a反串是b的前缀,并且b剩余的后缀是回文串,要么就是a是b反串的前缀,并且b剩余的后缀是回文串,

当a串长度等于b时候,a反串是等于b。

那么我们先用manachar求出到达每个串的长度哪个位置能构成回文串,并且全部插入字典树中然后查询每个串跟其他串能够构成回文串的个数是多少

这样我们就求出所有|a|<=|b|的满足是b的前缀的串,接着我们再将所有的串改成其对应的反串,然后再跑马拉车,然后再统计一下一些a是b的反串的前缀个数,那么这样会出现一个问题,就是我可能重复统计,比如三个那么会重复统计的情况只有一种那就是a的反串等于b会重复统计,因为只有这种情况下,我在第一步统计a反串是b的前缀的时候把相等情况一起讨论了,然后我会出现说原来

ab ba,这种方案下ba,ab这样的情况重复统计所以我们用一另外一个点标记一下这个结点有多少串的结尾第二次统计的时候减去即可

#include<iostream>
#include<cstring>
#include<algorithm>
#include<queue>
#include<string>
#include<vector>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<cstdlib>
#include<ctime>
#include<stack>
#include<bitset>
using namespace std;
#define mes(a,b) memset(a,b,sizeof(a))
#define rep(i,a,b) for(int i = a; i <= b; i++)
#define dec(i,a,b) for(int i = b; i >= a; i--)
#define fi first
#define se second
#define ls rt<<1
#define rs rt<<1|1
#define lson ls,L,mid
#define rson rs,mid+1,R
#define lowbit(x) x&(-x)
typedef double db;
typedef long long int ll;
typedef pair<int,int> pii;
typedef unsigned long long ull;
const ll inf = 0x3f3f3f3f;
const int mx = 2e6+5;
const int mod = 1e9+7;
const int x_move[] = {1,-1,0,0,1,1,-1,-1};
const int y_move[] = {0,0,1,-1,1,-1,1,-1};
int n,m;
ll ans;
char str[mx];
char b[mx*2];
int p[mx*2];
int w[mx];
int c[mx];
int len[mx];
int start[mx];
struct node{
    int ch[mx][26];
    int val[mx];
    int cnt[mx];
    int sz;
    void init(){
        sz = 1;
        mes(ch[0],0);
        val[0] = 0;
        cnt[0] = 0;
    }
    void insert(char str[],int len,int w[]){
        int u = 0;
        for(int i = 1; i <= len; i++){
            int d = str[i]-'a';
            if(!ch[u][d]){
                mes(ch[sz],0);
                val[sz] = 0;
                cnt[sz] = 0;
                ch[u][d] = sz++;
            }
            u = ch[u][d];
            val[u] += w[i];
        }
        cnt[u]++;
    }
    int find(char str[]){
        int u = 0;
        for(int i = 0; str[i]; i++){
            int d = str[i]-'a';
            if(!ch[u][d])
                return 0;
            u  = ch[u][d];
        }
        return val[u];
    }
    int getre(char str[]){
        int u = 0;
        for(int i = 0; str[i]; i++){
            int d = str[i]-'a';
            if(!ch[u][d])
                return 0;
            u = ch[u][d];
        }
        return cnt[u];
    }
}word;
void manachar(char str[],int len){
    m = 0;
    b[m] = '$';
    for(int i = 1; i <= len; i++)
        b[++m] = '#',b[++m] = str[i],w[i] = 0;
    b[++m] = '#';
    b[m+1] = 0;
    int po = 0,mx = 0;
    for(int i = 1; i <= m; i++){
        if(mx>i) p[i] = min(p[2*po-i],mx-i);
        else p[i] = 1;
        while(b[i-p[i]]==b[i+p[i]]) p[i]++;
        if(i+p[i]>mx)
            po = i,mx = i+p[i];
        if(i+p[i]>m)
            w[(i-p[i])/2] = 1;
    }
    word.insert(str,len,w);
}
void search(int be,int en,int flag){
    int m = 0;
    for(int i = en; i >= be; i--)
        b[m++] = str[i];
    b[m] = 0;
    ans += word.find(b);
    if(flag)
        ans -= word.getre(b);
}
void re(int be,int en){
    while(be<=en){
        swap(str[be],str[en]);
        be++;
        en--;
    }
}
void solve(){
    word.init();
    for(int i = 1; i <= n; i++){
        re(start[i],start[i]+len[i]-1);
        manachar(str+start[i]-1,len[i]);
    }
    for(int i = 1; i <= n; i++)
        search(start[i],start[i]+len[i]-1,1);

}
int main(){
    int t,q,ca = 1;
    while(scanf("%d",&n)!=EOF){
    int l = 1;
    word.init();
    ans = 0;
    for(int i = 1; i <= n; i++){
        scanf("%d%s",&len[i],str+l);
        start[i] = l;
        manachar(str+l-1,len[i]);
        l += len[i];
    }
    for(int i = 1; i <= n; i++)
        search(start[i],start[i]+len[i]-1,0);
    solve();
    printf("%lld\n",ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/a1325136367/article/details/81093950