快12点了= =。
实在是忍不住想切这题,写个博客记录一下思路明天再搞(虽然这篇博客过了审核的时候我应该又开始搞了)。
题目简化后是这样的,给两个串s和t,求从s中取一个子串与t的一个前缀连接后能组成回文串的个数,但是要求从s中取的子串长度要大于所选择的t的前缀。
一开始漏看了后半段的条件…一想这不是sam的裸题吗,s串倒着建sam,然后t串进去走,走不动了就立马停,接着ringt集合大小搞一通。
然后一看样例不对…这样例一(s = ababa,t = aba)怎么是个5啊。
然后再一看要求取的子串必须大于取的前缀…
又一想不对啊…就算没有这个条件也要考虑中间的回文,有了这个条件反而简单了。
那正确的思路应该是,取的子串的左边倒过来与t中取的前缀相匹配,然后中间夹着一个回文串。
那么第一步毫无疑问是…处理出s串中所有回文串的位置。
怎么处理呢…manacher或者回文树吧,sam应该也可以直接处理出回文串,在2015年张天扬的国家集训队论文里见过这种操作。
不会,告辞.jpg,我只会用回文算法求回文。
嗯…之后呢,之后该怎么办…
匹配位置加前一个位置(倒着来的嘛)的回文串个数…
啊写不下去了写不下去了再写门禁了明天再搞。
2018.10.15日中午:
我写完了…
应该没问题吧…
思路是s串倒着建sam,加入每个节点的出现位置信息end,然后建出后缀树(还真是正牌后缀树,这里刚好反着建了)。
接着回文树也倒着处理出每一个位置的回文个数。
然后每走一步就进入后缀树去遍历每一个节点,每到一个就代表着当前t的前缀出现了一次,加上其下一个位置的回文个数即可。注意这个时候不考虑克隆的节点,遇到克隆的节点跳过直接向下走,因为克隆的节点必然会在之后的主轴上考虑过,在克隆时顺手将克隆节点的end标记为-1。
复杂度妥妥O(n)。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 1e6 + 5;
char s[maxn], t[maxn];
struct Sam {
int next[maxn << 1][26];
int link[maxn << 1], step[maxn << 1];
vector<int> G[maxn << 1];
int end[maxn << 1];
int sz, last, len, root;
void init() {
//如多次建立自动机,加入memset操作
root = sz = last = 1;
}
void add(int c, int id) {
int p = last;
int np = ++sz;
last = np;
end[np] = id;
step[np] = step[p] + 1;
while(!next[p][c] && p) {
next[p][c] = np;
p = link[p];
}
if(p == 0) {
link[np] = root;
} else {
int q = next[p][c];
if(step[p] + 1 == step[q]) {
link[np] = q;
} else {
int clone = ++sz;
memcpy(next[clone], next[q], sizeof(next[q]));
step[clone] = step[p] + 1;
link[clone] = link[q];
end[clone] = -1;
link[q] = link[np] = clone;
while(next[p][c] == q && p) {
next[p][c] = clone;
p = link[p];
}
}
}
}
void build() {
init();
len = strlen(s);
for(int i = len - 1; i >= 0; i--) {
add(s[i] - 'a', i);
}
for(int i = sz; i > 1; i--) {
G[link[i]].push_back(i);
}
// for (int i = 1; i <= sz; i++) {
// cout << i;
// for (int j = 0; j < G[i].size(); j++)
// cout << ' ' << G[i][j];
// cout << endl;
// }
}
} sam;
struct Pam {
int next[maxn][26];
int fail[maxn];
int len[maxn];// 当前节点表示回文串的长度
int num[maxn];// 到当前节点这里有多少本质不同的回文子串
int pa[maxn];
int S[maxn];
int last, n, p;
int newNode(int l) {
memset(next[p], 0, sizeof(next[p]));
len[p] = l;
num[p] = 0;
return p++;
}
void init() {
n = last = p = 0;
newNode(0);
newNode(-1);
S[n] = -1;
fail[0] = 1;
}
int getFail(int x) {
while(S[n - len[x] - 1] != S[n]) {
x = fail[x];
}
return x;
}
int add(int c) {
S[++n] = c;
int cur = getFail(last);
if(!next[cur][c]) {
int now = newNode(len[cur] + 2);
fail[now] = next[getFail(fail[cur])][c];
next[cur][c] = now;
num[now] = num[fail[now]] + 1;
}
last = next[cur][c];
return num[last];
}
void build() {
init();
int lenn = strlen(s);
for(int i = lenn - 1; i >= 0; i--) {
pa[i] = add(s[i] - 'a');
}
}
} pam;
/*
ababa
bc
aabbaa
aabb
abac
acba
*/
ll ans = 0;
void dfs(int x, int pre) {
int id = sam.end[x] + pre;
// printf("id = %d\n", id);
if(sam.end[x] != -1) {
ans += pam.pa[id];
}
// printf("%d %d\n", pam.pa[id], sam.end[x]);
for(int i = 0; i < sam.G[x].size(); i++) {
dfs(sam.G[x][i], pre);
}
}
void solve() {
sam.build();
pam.build();
int p = sam.root, len = strlen(t);
// for(int i = 0; i < sam.step[sam.last]; i++) {
// printf("%d ", pam.pa[i]);
// }
// printf("\n");
for(int i = 0, c; i < len; i++) {
c = t[i] - 'a';
if(!sam.next[p][c]) {
break;
}
p = sam.next[p][c];
dfs(p, i + 1);
// printf("\n");
}
printf("%lld\n", ans);
}
int main() {
scanf("%s%s", s, t);
solve();
return 0;
}