开始对后缀数组上瘾了。。。
题意:给两个字符串,求公共子串数量。
把两个串拼在一起并且在中间加一个分隔符,然后统计出这个长串有多少对位置不同的相等子串。这些子串要么同时来自 ,要么同时来自 ,要么一个来自 一个来自 ;而我们需要的答案是最后一类。所以我们只要对 和 分别求一次,用总数减去它们的答案即可。
现在考虑如何求一个串 有多少对相等的子串。求出它的后缀数组以及 数组,那么问题就是要求
考虑从左向右加入数,每次加入一个数 ,想办法维护出以它为右端点,以它左边每个位置为左端点的区间最小值之和。这个东西可以借助一个单调栈,维护:
那么每次加入一个数,先计算它答案的贡献 ,然后将它加入栈顶。
在栈上加入或删除元素的同时从 中加上或减去相应的值即可。
其实还有一些其它做法:我们可以求出以每个位置为最小值点,向左向右分别最远能延伸多远,然后直接乘法原理计算即可,这个也需要用单调栈求。
反正怎么开心怎么玩吧。
#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>
#include <cstring>
template <typename T> void write(T x) {
if (x < 0) x = -x, putchar('-');
if (x > 9) write(x / 10);
putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }
typedef long long LL;
const int maxn = 4e5 + 207;
char s1[maxn], s2[maxn];
int sa[maxn], rank[maxn], tmp[maxn], tax[maxn], height[maxn];
int stk[maxn], top;
int n1, n2, sigma;
inline void rsort(int n) {
std::fill(tax, tax + sigma + 1, 0);
for (int i = 1; i <= n; ++i) ++tax[rank[i]];
for (int i = 1; i <= sigma; ++i) tax[i] += tax[i - 1];
for (int i = n; i; --i) sa[tax[rank[tmp[i]]]--] = tmp[i];
}
inline void getsa(int n, char *s) {
for (int i = 1; i <= n; ++i)
rank[i] = s[i] - 'a' + 1, tmp[i] = i;
rsort(n);
for (int w = 1, p = 0; p < n; sigma = p, w <<= 1) {
int cnt = 0;
for (int i = 1; i <= w; ++i) tmp[++cnt] = n - w + i;
for (int i = 1; i <= n; ++i) if (sa[i] > w) tmp[++cnt] = sa[i] - w;
rsort(n); std::swap(tmp, rank);
rank[sa[1]] = p = 1;
for (int i = 2; i <= n; ++i)
rank[sa[i]] = tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w] ? p : ++p;
}
}
inline void getheight(int n, char *s) {
for (int i = 1, k = 0; i <= n; ++i) {
if (rank[i] == 1) { height[rank[i]] = k = 0; continue; }
if (k) --k;
int j = sa[rank[i] - 1];
while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
height[rank[i]] = k;
}
}
inline LL calc(int n, char *s) {
getsa(n, s);
getheight(n, s);
LL sum = 0, ans = 0;
stk[top = 0] = 1;
for (int i = 2; i <= n; ++i) {
while (top && height[stk[top]] > height[i]) {
sum -= 1ll * height[stk[top]] * (stk[top] - stk[top - 1]);
--top;
}
ans += sum + 1ll * height[i] * (i - stk[top]);
stk[++top] = i;
sum += 1ll * height[stk[top]] * (stk[top] - stk[top - 1]);
}
return ans;
}
int main() {
scanf("%s", s1 + 1);
scanf("%s", s2 + 1);
n1 = strlen(s1 + 1);
n2 = strlen(s2 + 1);
LL ans = 0;
sigma = 26;
ans -= calc(n1, s1);
sigma = 26;
ans -= calc(n2, s2);
sigma = 27;
s1[++n1] = 'z' + 1;
for (int i = n1 + 1, j = 1; j <= n2; ++i, ++j) s1[i] = s2[j];
ans += calc(n1 + n2, s1);
writeln(ans);
return 0;
}