「HAOI2016 找相同字符」「SA」「单调栈」「前缀和」

首先可以想到一个暴力的 \(\mathcal{O}(n^3)\) 算法:枚举 \(\text{A}\) 的一个后缀和 \(\text{B}\) 的一个后缀,算出它们的最长公共前缀(LCP)的长度,所有这些长度之和就是答案。

这样显然是对的,但是也显然可以用后缀数组优化。

把 \(\text{A}\)、\(\text{B}\) 两个串用一个没出现过的字符隔开然后连起来,对新串求后缀数组。那么原来两个后缀的 LCP,就等于新串里对应位置两个后缀的 LCP,也就是 height 数组上一段区间的 min。

那么区间 height 的 min 就启发我们用单调栈来枚举这个 min 然后算贡献,对于这个min,左端点的选取是一个区间,右端点也是,故统计这两段内有多少对后缀在原串中一个在A,一个在B即可。

这个东西某个 sb 一开始写了主席树求,后来发现直接前缀和一下就可以了。

所以复杂度是 \(\mathcal{O}(n \log_2 n)\),瓶颈在求 SA。

具体为啥我调了一天,大概是因为sa求错了一直wa10...

#include <bits/stdc++.h>

#pragma GCC optimize("Ofast","-funroll-loops","-fdelete-null-pointer-checks")
#pragma GCC target("ssse3","sse3","sse2","sse","avx2","avx")

#define rep(i, l, r) for (int i = (l); i <= (r); ++i)
#define per(i, r, l) for (int i = (r); i >= (l); --i)
using namespace std;

typedef long long ll;
typedef pair <int, int> pii;
typedef vector <int> vi;

/**
 * Reads the next decimal integer from standard input.
 *
 * Every character before the first digit is skipped; if a '-' occurs among
 * the skipped characters the result is negated. The run of digits is then
 * consumed together with the single character that follows it.
 *
 * @return the parsed value, or 0 when end of input is reached before any
 *         digit was found.
 */
int gi() {
    int sign = 1, value = 0;
    // getchar() returns int; keeping it as int lets EOF be told apart from
    // real characters so the skip loop terminates at end of input.
    int ch = getchar();
    while (ch != EOF && (ch < '0' || ch > '9')) {
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch >= '0' && ch <= '9') {
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return sign * value;
}
// Capacity of the arrays that are indexed by positions/ranks of the
// concatenated string; each input buffer gets half of it (N >> 1).
const int N = 400005; 
// s1, s2: the two input strings (main reads them starting at index 1).
// s:      the string main assembles from s1, a separator and s2 (1-indexed).
char s1[N >> 1], s2[N >> 1], s[N];
// n:                length of s.
// sa, rk, h:        suffix array, rank array and height array filled by get_SA.
// id, px, cnt, rk_: scratch buffers used by get_SA (rk_ is also read by cmp).
// l, r:             per-rank neighbour bounds computed in main.
// st, tp:           array-based stack and its size, used in main.
int n, sa[N], rk[N], id[N], px[N], cnt[N], rk_[N], h[N], l[N], r[N], st[N], tp;
// Tells whether positions x and y had the same pair of previous-round ranks:
// the rank at the position itself and the rank w positions further on.
// The second pair is only inspected when the first pair already matches.
bool cmp(int x, int y, int w) {
    if (rk_[x] != rk_[y]) {
        return false;
    }
    const int tail_x = rk_[x + w];
    const int tail_y = rk_[y + w];
    return tail_x == tail_y;
}
// Fills sa, rk and h for the 1-indexed string s[1..n].
//   sa[k]  : start position of the suffix placed k-th by the sorting below
//   rk[i]  : position of suffix i inside sa
//   h[k]   : number of leading characters shared by suffixes sa[k-1] and sa[k]
//            (h of the first-ranked suffix is never written here)
// Uses the global scratch arrays cnt, id, px and rk_.
void get_SA() {
    // m is the largest key value the counting sort has to handle; the first
    // round uses raw character codes as keys.
    int i, j, m = 300, p, w;
    // Round 0: counting sort of all positions by their first character.
    for (i = 1; i <= n; ++i) ++cnt[rk[i] = s[i]];
    for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
    for (i = n; i >= 1; --i) sa[cnt[rk[i]]--] = i;
    // Each round doubles w; after a round, m becomes the number of distinct
    // ranks (p) that were handed out.
    for (w = 1; w < n; w <<= 1, m = p) {
        // id lists positions ordered by the second key (rank of position + w):
        // positions with no character at offset w come first ...
        for (p = 0, i = n; i > n - w; --i) id[++p] = i;
        // ... followed by the rest, taken in the current sa order shifted by w.
        for (i = 1; i <= n; ++i) if (sa[i] > w) id[++p] = sa[i] - w;
        memset (cnt, 0, sizeof(cnt));
        // Counting sort of id by the first key (current rank); walking id
        // backwards keeps the second-key order inside equal first keys.
        for (i = 1; i <= n; ++i) ++cnt[px[i] = rk[id[i]]];
        for (i = 1; i <= m; ++i) cnt[i] += cnt[i - 1];
        for (i = n; i >= 1; --i) sa[cnt[px[i]]--] = id[i];
        // Keep the old ranks in rk_ so cmp can compare them while rk is rewritten.
        memcpy(rk_, rk, sizeof(rk));
        // Neighbours in sa with identical (first, second) keys share a rank.
        // For i == 1 this reads sa[0], which nothing in this function writes.
        for (p = 0, i = 1; i <= n; ++i) 
            rk[sa[i]] = (cmp(sa[i - 1], sa[i], w) ? p : ++p);
    }
    // Height array: walk positions in text order and reuse the previous match
    // length j, lowered by one, as the starting point for the next position.
    for (i = 1, j = 0; i <= n; ++i) {
        if (rk[i] == 1) continue; // first-ranked suffix has no predecessor in sa
        int now = sa[rk[i] - 1];
        // NOTE(review): the scan has no explicit bound; it presumably relies on
        // s[n + 1] still being the zero a global array starts with - confirm.
        while (s[i + j] == s[now + j]) ++j;
        h[rk[i]] = j;
        if (j) --j;
    }
}
// Prefix counts over sa order, filled in main: S1[i] / S2[i] is how many of
// sa[1..i] start inside the first / second input string.
int S1[N], S2[N];
int main() {
    scanf("%s%s",s1 + 1, s2 + 1);
    int len = strlen(s1 + 1), len_ = strlen(s2 + 1);
    n = len + len_ + 1;
    rep (i, 1, n) {
        if (i <= len) s[i] = s1[i];
        else if (i == len + 1) s[i] = '$';
        else s[i] = s2[i - len - 1];
    }
    get_SA();
    rep (i, 1, n) l[i] = 0, r[i] = n + 1;
    rep (i, 1, n) {
        while (tp && h[st[tp]] >= h[i]) --tp;
        if (tp) l[i] = st[tp];
        st[++tp] = i;
    }
    tp = 0;
    per (i, n, 1) {
        while (tp && h[st[tp]] > h[i]) --tp;
        if (tp) r[i] = st[tp];
        st[++tp] = i;
    }
    ll ans = 0; 
    rep (i, 1, n) {
        S1[i] = S1[i - 1] + (1 <= sa[i] && sa[i] <= len);
        S2[i] = S2[i - 1] + (len + 2 <= sa[i] && sa[i] <= n); 
    }
    rep (i, 1, n) {
        int l_ = l[i] + 1, r_ = r[i] - 1;
        ans += 1ll * h[i] * (S1[i - 1] - S1[l_ - 2]) * (S2[r_] - S2[i - 1]);
        ans += 1ll * h[i] * (S2[i - 1] - S2[l_ - 2]) * (S1[r_] - S1[i - 1]);
    } 
    cout << ans << '\n'; 
  return 0;
}

猜你喜欢

转载自www.cnblogs.com/LiM-817/p/12305747.html