【NOI2016】【BZOJ4650】【UOJ219】【LOJ2083】优秀的拆分

【题目链接】

【前置技能】

  • 后缀数组

【题解】

  • 首先,我们将优秀的拆分拆成一半来看,令 a [ p o s ] 表示从 p o s 开始的 A A 串的个数,令 b [ p o s ] 表示以 p o s 结尾的 A A 串的个数,那么答案就是 i = 1 L E N 1 a [ i + 1 ] b [ i ]
  • 那么接下来的问题就是如何求出 a 数组和 b 数组。我们考虑一个长度为 l e n A 串,我们在原串的 1 l e n , 2 l e n , 3 l e n 处设置断点,那么 A A 串的两个 A 就会分别经过相邻的两个断点,且断点处是 A 中的同一位。那么求出相邻两个断点的 l c p l c s ,如果 l c p + l c s 1 l e n 那么就存在 A A 串,差分一下统计入两个数组即可。注意,求的 l c p l c s 都要和 l e n 取min,避免串的重复计算。
  • 时间复杂度 O ( N l o g N )

【代码】

#include<bits/stdc++.h>
#define INF 0x3f3f3f3f
#define LL  long long
#define MAXN    60100
#define MAXLOG  18
using namespace std;
int n, Q, a[MAXN], b[MAXN], rnk[MAXN], sa[MAXN], hei[MAXN], st[MAXN][MAXLOG + 2], LOG[MAXN];
char s[MAXN];
LL ans;

template <typename T> void chkmin(T &x, T y){x = min(x, y);}
template <typename T> void chkmax(T &x, T y){x = max(x, y);}
template <typename T> void read(T &x){
    x = 0; int f = 1; char ch = getchar();
    while (!isdigit(ch)) {if (ch == '-') f = -1; ch = getchar();}
    while (isdigit(ch)) {x = x * 10 + ch - '0'; ch = getchar();}
    x *= f;
}

void suffix(int n){
    static int x[MAXN], y[MAXN], rk[MAXN], cnt[MAXN];
    memset(cnt, 0, sizeof(cnt));
    for (int i = 1; i <= n; ++i)
        ++cnt[s[i] - 'a'];
    for (int i = 1; i < 27; ++i)
        cnt[i] += cnt[i - 1];
    for (int i = n; i >= 1; --i)
        sa[cnt[s[i] - 'a']--] = i;
    rnk[sa[1]] = 1;
    for (int i = 2; i <= n; ++i)
        rnk[sa[i]] = rnk[sa[i - 1]] + (s[sa[i]] != s[sa[i - 1]]);
    for (int len = 1; rnk[sa[n]] != n; len <<= 1){
        for (int i = 1; i <= n; ++i)
            x[i] = rnk[i], y[i] = (i + len <= n) ? rnk[i + len] : 0;
        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; ++i)
            ++cnt[y[i]];
        for (int i = 1; i <= n; ++i)
            cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i)
            rk[cnt[y[i]]--] = i;
        memset(cnt, 0, sizeof(cnt));
        for (int i = 1; i <= n; ++i)
            ++cnt[x[i]];
        for (int i = 1; i <= n; ++i)
            cnt[i] += cnt[i - 1];
        for (int i = n; i >= 1; --i)
            sa[cnt[x[rk[i]]]--] = rk[i];
        rnk[sa[1]] = 1;
        for (int i = 2; i <= n; ++i)
            rnk[sa[i]] = rnk[sa[i - 1]] + (x[sa[i]] != x[sa[i - 1]] || y[sa[i]] != y[sa[i - 1]]);
    }
}

void gethei(int n){
    int cur = 0;
    for (int i = 1; i <= n; ++i){
        if (cur) --cur;
        for (int j = sa[rnk[i] + 1]; s[i + cur] == s[j + cur]; ++cur) ;
        hei[rnk[i]] = cur;
    }
}

void getrmq(int n){
    for (int i = 1; i <= n; ++i)
        st[i][0] = hei[i];
    for (int len = 1; len < MAXLOG; ++len){
        for (int i = 1; i + (1 << len) - 1 <= n; ++i)
            st[i][len] = min(st[i][len - 1], st[i + (1 << (len - 1))][len - 1]);
    }
}

int rmq(int x, int y){
    if (x > y) swap(x, y);
    int len = LOG[y - x + 1];
    return min(st[x][len], st[y - (1 << len) + 1][len]);
}

void adda(int l, int r){
    ++a[l], --a[r + 1];
}

void addb(int l, int r){
    ++b[l], --b[r + 1];
}

int lcp(int l, int r){
    l = rnk[l], r = rnk[r];
    if (l > r) swap(l, r); --r;
    return rmq(l, r);
}

int lcs(int l, int r){
    l = rnk[2 * n - l + 2], r = rnk[2 * n - r + 2];
    if (l > r) swap(l, r); --r;
    return rmq(l, r);
}

void work(){
    suffix(2 * n + 1);
    gethei(2 * n + 1);
    getrmq(2 * n + 1);
    for (int len = 1; len <= n / 2; ++len){
        for (int i = 1, j = i + 1; j * len <= n; ++i, ++j){
            int suf = min(lcp(i * len, j * len), len), pre = min(lcs(i * len, j * len), len);
            if (pre + suf - 1 >= len){
                adda(i * len - pre + 1, i * len + suf - len);
                addb(j * len - pre + len, j * len + suf - 1);
            }
        }
    }
    for (int i = 1; i <= n; ++i)
        a[i] += a[i - 1], b[i] += b[i - 1];
}

int main(){
    LOG[1] = 0;
    for (int i = 2; i < MAXN; ++i)
        LOG[i] = LOG[i - 1] + (i == (1 << (LOG[i - 1] + 1)));
    read(Q);
    while (Q--){
        scanf("%s", s + 1);
        n = strlen(s + 1);
        s[n + 1] = (char)'z' + 1;
        for (int i = 1; i <= n; ++i)
            s[n + i + 1] = s[n - i + 1];
        memset(a, 0, sizeof(a));
        memset(b, 0, sizeof(b));
        memset(sa, 0, sizeof(sa));
        work();
        ans = 0;
        for (int i = 1; i <= n; ++i)
            ans = ans + 1ll * b[i - 1] * a[i];
        printf("%lld\n", ans);
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/six_solitude/article/details/81177809
今日推荐