[HAOI2016]找相同字符 题解

传送门

开始对后缀数组上瘾了。。。

题意:给两个字符串,求公共子串数量。

把两个串拼在一起并且在中间加一个分隔符,然后统计出这个长串有多少对位置不同的相等子串。这些子串要么同时来自 S 1 S_1 ,要么同时来自 S 2 S_2 ,要么一个来自 S 1 S_1 一个来自 S 2 S_2 ;而我们需要的答案是最后一类。所以我们只要对 S 1 S_1 S 2 S_2 分别求一次,用总数减去它们的答案即可。

现在考虑如何求一个串 S S 有多少对相等的子串。求出它的后缀数组以及 h e i g h t \mathrm{height} 数组,那么问题就是要求

i = 1 n j = i + 1 n l c p ( s a i , s a j ) \sum\limits_{i=1}^n\sum\limits_{j=i+1}^n\mathrm{lcp}(\mathrm{sa}_i,\mathrm{sa}_j)

= i = 2 n j = i n min k = i j h e i g h t k =\sum\limits_{i=2}^n\sum\limits_{j=i}^n\min\limits_{k=i}^j\mathrm{height}_k

考虑从左向右加入数,每次加入一个数 h e i g h t i \mathrm{height}_i ,想办法维护出以它为右端点,以它左边每个位置为左端点的区间最小值之和。这个东西可以借助一个单调栈,维护:

s u m = j = 1 t o p h e i g h t s t k j ( s t k j s t k j 1 ) sum=\sum\limits_{j=1}^{top}\mathrm{height}_{\mathrm{stk}_j}\cdot(\mathrm{stk}_j-\mathrm{stk}_{j-1})

那么每次加入一个数,先计算它答案的贡献 s u m + h e i g h t i ( i s t k t o p ) sum+\mathrm{height}_i\cdot(i-\mathrm{stk}_{top}) ,然后将它加入栈顶。

在栈上加入或删除元素的同时从 s u m sum 中加上或减去相应的值即可。

其实还有一些其它做法:我们可以求出以每个位置为最小值点,向左向右分别最远能延伸多远,然后直接乘法原理计算即可,这个也需要用单调栈求。

反正怎么开心怎么玩吧。

#include <cctype>
#include <cstdio>
#include <climits>
#include <algorithm>
#include <cstring>

template <typename T> void write(T x) {
    if (x < 0) x = -x, putchar('-');
    if (x > 9) write(x / 10);
    putchar(x % 10 + 48);
}
template <typename T> inline void writeln(T x) { write(x); puts(""); }
template <typename T> inline bool chkmin(T& x, const T& y) { return y < x ? (x = y, true) : false; }
template <typename T> inline bool chkmax(T& x, const T& y) { return x < y ? (x = y, true) : false; }

typedef long long LL;

const int maxn = 4e5 + 207;

char s1[maxn], s2[maxn];
int sa[maxn], rank[maxn], tmp[maxn], tax[maxn], height[maxn];
int stk[maxn], top;
int n1, n2, sigma;

inline void rsort(int n) {
    std::fill(tax, tax + sigma + 1, 0);
    for (int i = 1; i <= n; ++i) ++tax[rank[i]];
    for (int i = 1; i <= sigma; ++i) tax[i] += tax[i - 1];
    for (int i = n; i; --i) sa[tax[rank[tmp[i]]]--] = tmp[i];
}
inline void getsa(int n, char *s) {
    for (int i = 1; i <= n; ++i)
        rank[i] = s[i] - 'a' + 1, tmp[i] = i;
    rsort(n);
    for (int w = 1, p = 0; p < n; sigma = p, w <<= 1) {
        int cnt = 0;
        for (int i = 1; i <= w; ++i) tmp[++cnt] = n - w + i;
        for (int i = 1; i <= n; ++i) if (sa[i] > w) tmp[++cnt] = sa[i] - w;
        rsort(n); std::swap(tmp, rank);
        rank[sa[1]] = p = 1;
        for (int i = 2; i <= n; ++i)
            rank[sa[i]] = tmp[sa[i]] == tmp[sa[i - 1]] && tmp[sa[i] + w] == tmp[sa[i - 1] + w] ? p : ++p;
    }
}
inline void getheight(int n, char *s) {
    for (int i = 1, k = 0; i <= n; ++i) {
        if (rank[i] == 1) { height[rank[i]] = k = 0; continue; }
        if (k) --k;
        int j = sa[rank[i] - 1];
        while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) ++k;
        height[rank[i]] = k;
    }
}
inline LL calc(int n, char *s) {
    getsa(n, s);
    getheight(n, s);
    LL sum = 0, ans = 0;
    stk[top = 0] = 1;
    for (int i = 2; i <= n; ++i) {
        while (top && height[stk[top]] > height[i]) {
            sum -= 1ll * height[stk[top]] * (stk[top] - stk[top - 1]);
            --top;
        }
        ans += sum + 1ll * height[i] * (i - stk[top]);
        stk[++top] = i;
        sum += 1ll * height[stk[top]] * (stk[top] - stk[top - 1]);
    }
    return ans;
}

int main() {
    scanf("%s", s1 + 1);
    scanf("%s", s2 + 1);
    n1 = strlen(s1 + 1);
    n2 = strlen(s2 + 1);
    LL ans = 0;
    sigma = 26;
    ans -= calc(n1, s1);
    sigma = 26;
    ans -= calc(n2, s2);
    sigma = 27;
    s1[++n1] = 'z' + 1;
    for (int i = n1 + 1, j = 1; j <= n2; ++i, ++j) s1[i] = s2[j];
    ans += calc(n1 + n2, s1);
    writeln(ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/qq_39677783/article/details/89789382