回文树模板(求回文子串中不同字符数量)

Colorful String

The value of a string ss is equal to the number of different letters which appear in this string.

Your task is to calculate the total value of all the palindrome substring.

Input

The input consists of a single string ∣s∣(1≤∣s∣≤3×105).

The string ss only contains lowercase letters.

Output

Output an integer that denotes the answer.

样例输入

abac

样例输出

6

样例解释

abacabac has palindrome substrings a,b,a,c,aba,ans the total value is equal to 1+1+1+1+2=6。

链接:https://nanti.jisuanke.com/t/41389

题意:每个回文字串中,不同字符数量。

题解:回文树板子,在树的每个节点加上sum[maxn],cnt2[maxn][30],前者是每个节点不同字母计数,后者是该节点有哪些字母。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <cstring>
#define inf 0x3f3f3f3f
using namespace std;
const int maxn=1e6+5;

struct PAM{
	int sum[maxn];     //每个节点不同字母数量(模板添加)
	int cnt2[maxn][30];//每个节点字母统计(模板添加)
	int nex[maxn][30];//指向的串为当前串两端加上同一个字符构成
	int fail[maxn];//fail跳转到自己这个串的最长回文后缀 
	int cnt[maxn];//出现次数 
	int num[maxn];// 表示以节点i表示的最长回文串的最右端点为回文串结尾的回文串个数。
	int len[maxn];//len[i]表示节点i表示的回文串的长度 
	int S[maxn];//存放添加的字符
	int last,n,p;//last指向上一个字符所在的节点,方便下一次add
	
	int create(int rt){//新建节点 
		memset(nex[p],0,sizeof(nex[p]));
		cnt[p]=0;
		num[p]=0;
		len[p]=rt;
		return p++;
	}
	
	void init(){//初始化 
		p=last=n=0;
		create(0);
		create(-1);
		S[0]=-1;
		fail[0]=1;
	}
	
	int getFail(int x){//寻找失败节点 
		while(S[n-len[x]-1]!=S[n])	x=fail[x];
		return x;
	}
	void insert(int c)//插入字符 
	{
		c=c-'a';
		S[++n]=c;
		int cur=getFail(last);
		if(!nex[cur][c]){
			int now=create(len[cur]+2);
			fail[now]=nex[getFail(fail[cur])][c];
			nex[cur][c]=now;
        	//-------------------------------模板添加
        	sum[now] = sum[cur];
        	for(int j = 0; j < 26; j++) 
			{
				if(cnt2[cur][j] == 1)
				cnt2[now][j] = cnt2[cur][j];
			}
			if(cnt2[now][c] == 0){
				cnt2[now][c] = 1;
				sum[now]++;
			}    
		//-----------------------------------
			num[now]=num[fail[now]]+1;
		}
		last=nex[cur][c];
		cnt[last]++;
	}
	
	void count()//cnt答案不准确,需要调用更新一下。 
	{
		long long ans=0;
		for (int i = p-1; i >= 0; i--)
        	cnt[ fail[i] ] += cnt[i];
	}
	
	void set(string s){//设置字符串 
		init();
		int len=s.size();
		for(int i=0;i<len;i++)
			insert(s[i]);
	}
}pam;

int main()
{
	string s;
	cin>>s;
	pam.set(s);
	pam.count();
	long long ans=0;
	for(int i=2;i<=pam.p-1;i++)
	{
		ans+=(pam.sum[i])*pam.cnt[i];
	}
	cout<<ans<<endl;
	return 0;
}
发布了39 篇原创文章 · 获赞 27 · 访问量 4127

猜你喜欢

转载自blog.csdn.net/qq_43381887/article/details/100853252
今日推荐