https://vjudge.net/problem/POJ-3415
其实很早以前这道题就过了,但不过因为后缀数组的方法自己也不是很懂就没有写。
今天我学了一下后缀自动机,我们利用后缀自动机解决。
首先对A串建立后缀自动机,然后对每一个非克隆结点进行标记, 然后建一颗fail树,用一个dfs进行计数,这个步骤就是其他博客上写的拓扑排序。
-
首先对于一个结点,他所包含的子串个数为 ,那么包含长度大于等于k,那么就是 上述的结论以及推导是可以由后缀自动机的定义得到的。
-
然后我们用B串到自动机上进行匹配,并且记录匹配长度,这里可能有同学会犯错,我就犯错了,我认为匹配的状态的len数组和匹配的长度是一样的,因此就这样标记了,然后会多算。我们继续,如果当匹配长度大于等于k时,就可以 这里的d1数组就是上面的dfs计数得到的。然后如果是 大于等于k那么就要对d2进行计数了,因为你不管你是从那一个儿子上来的,只要父亲满足上面那个条件,总会找到一个满足的串。
-
最后就仿造上面的步骤,按着刚刚的那个方法在fail树上进行计数就可以了。
整体来看SAM的效率还是挺高的,但不过确实不好理解。
//#include "bits/stdc++.h"
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <math.h>
#include <iostream>
using namespace std;
//inline int read() {
// int x = 0;
// bool f = 1;
// char c = getchar();
// for (; !isdigit(c); c = getchar()) if (c == '-') f = 0;
// for (; isdigit(c); c = getchar()) x = (x << 3) + (x << 1) + c - '0';
// if (f) return x;
// return 0 - x;
//}
typedef long long ll;
const int maxn = 110000 + 10;
int getc(char c) {
if (c >= 'a' && c <= 'z') return c - 'a';
return c - 'A' + 26;
}
int len[maxn << 1], link[maxn << 1];
int ch[maxn << 1][60], sz, last;
int d1[maxn << 1], d2[maxn << 1];
void init() {
len[0] = 0;
link[0] = -1;
sz = 1;
last = 0;
memset(d1, 0, sizeof(d1));
memset(d2, 0, sizeof(d2));
memset(ch[0], 0, sizeof(ch[0]));
}
void extend(int c) {
int cur = sz++, p = last;
len[cur] = len[last] + 1;
d1[cur] = 1;
memset(ch[cur], 0, sizeof(ch[cur]));
while (p != -1 && !ch[p][c]) {
ch[p][c] = cur;
p = link[p];
}
if (p == -1) {
link[cur] = 0;
} else {
int q = ch[p][c];
if (len[p] + 1 == len[q]) {
link[cur] = q;
} else {
int clone = sz++;
len[clone] = len[p] + 1;
memcpy(ch[clone], ch[q], sizeof(ch[q]));
link[clone] = link[q];
while (p != -1 && ch[p][c] == q) {
ch[p][c] = clone;
p = link[p];
}
link[q] = link[cur] = clone;
}
}
last = cur;
}
struct node {
int v, next;
} ed[maxn << 1];
int head[maxn << 1], cnt = 0;
void add_edge(int u, int v) {
++cnt;
ed[cnt].v = v;
ed[cnt].next = head[u];
head[u] = cnt;
}
void dfs(int u) {
for (int i = head[u]; i; i = ed[i].next) {
int v = ed[i].v;
dfs(v);
d1[u] += d1[v];
}
}
int k;
char a[maxn], b[maxn];
ll ans;
void dfs2(int u) {
for (int i = head[u]; i; i = ed[i].next) {
int v = ed[i].v;
dfs2(v);
if (len[u] >= k) d2[u] += d2[v];
}
ans += 1ll * d1[u] * d2[u] * (len[u] - max(k, len[link[u]] + 1) + 1);
}
int main() {
while (~scanf("%d", &k) && k) {
init();
memset(head, 0, sizeof(head));
cnt = 0;
scanf("%s%s", a, b);
int lena = strlen(a);
int lenb = strlen(b);
for (int i = 0; i < lena; i++) {
extend(getc(a[i]));
}
for (int i = 1; i < sz; i++) {
add_edge(link[i], i);
}
dfs(0);
ans = 0;
int p = 0, nowlen = 0;
for (int i = 0; i < lenb; i++) {
int id = getc(b[i]);
if (ch[p][id]) p = ch[p][id], nowlen++;
else {
while (p != -1 && !ch[p][id]) p = link[p];
if (p == -1) nowlen = 0, p = 0;
else nowlen = len[p] + 1, p = ch[p][id];
}
if (nowlen >= k) {
ans += 1ll * (nowlen - max(k, len[link[p]] + 1) + 1) * d1[p];
if (len[link[p]] >= k) d2[link[p]]++;
}
}
dfs2(0);
cout << ans << endl;
}
return 0;
}