Luogu2178 NOI2015 品酒大会 SA、并查集

传送门


感觉题目讲的很不清楚……

题目意思就是给出一个长度为\(n\)的字符串,求对于\(r=0,1,...,n-1\),求出\(LCP(suffix_p,suffix_q) \geq r\)的无序数对\((p,q)\)的数目,并令一对无序数对的价值为\(val_p \times val_q\),则还要求对于每一个\(r\),所有满足上述条件的无序数对中的最大价值

跟后缀\(LCP\)长度有关,直接上\(SA\)。求出\(sa\)数组和\(height\)数组,我们考虑如何实现对于每一个\(r\)的询问快速求出答案。不妨将\(r\)从大到小求解,那么对于某一个后缀\(sa_k\),满足\(LCP(suffix_{sa_p} , suffix_{sa_k}) \geq r\)\(p\)一定是一段区间,而且这一段区间随着\(r\)的缩小不断增大。

然后我们考虑如何拓展区间。考虑对于\(height_k=q\),当\(r>q\)的时候\(k\)位置两端的区间不会越过\(k-1\)\(k\),而当\(r \leq q\)时这两段区间就会合成一段区间。这个显然是可以使用并查集维护的,并且可以比较轻松地在并查集上维护最大价值。

#include<bits/stdc++.h>
#define mid ((l + r) >> 1)
#define lch Tree[x].l
#define rch Tree[x].r
//This code is written by Itst
using namespace std;

inline int read(){
    int a = 0;
    char c = getchar();
    bool f = 0;
    while(!isdigit(c) && c != EOF){
        if(c == '-')
            f = 1;
        c = getchar();
    }
    if(c == EOF)
        exit(0);
    while(isdigit(c)){
        a = (a << 3) + (a << 1) + (c ^ '0');
        c = getchar();
    }
    return f ? -a : a;
}

const int MAXN = 3e5 + 10;
int fa[MAXN] , val[MAXN] , valMax[MAXN][2] , valMin[MAXN][2];
int sa[MAXN] , rk[MAXN] , pot[MAXN] , tp[MAXN << 1] , h[MAXN];
int ind[MAXN] , size[MAXN] , N , maxN = 26;
char s[MAXN];
long long Max , cnt , ans[MAXN][2];

int find(int x){
    return fa[x] == x ? x : (fa[x] = find(fa[x]));
}

void Debug(){
    for(int i = 1 ; i <= N ; ++i)
        cout << sa[i] << ' ';
    cout << endl;
    for(int i = 1 ; i <= N ; ++i)
        cout << ind[i] << ' ';
    cout << endl << endl;
}

void input(){
    N = read();
    scanf("%s" , s + 1);
    for(int i = 1 ; i <= N ; ++i){
        val[i] = read();
        if(val[i] < 0)
            valMin[i][0] = val[i];
    }
}

void sort(int p){
    memset(pot , 0 , sizeof(pot));
    for(int i = 1 ; i <= N ; ++i)
        ++pot[rk[i]];
    for(int i = 1 ; i <= maxN ; ++i)
        pot[i] += pot[i - 1];
    for(int i = 1 ; i <= N ; ++i)
        sa[++pot[rk[tp[i]] - 1]] = tp[i];
    memcpy(tp , rk , sizeof(int) * (N + 1));
    for(int i = 1 ; i <= N ; ++i)
        rk[sa[i]] = rk[sa[i - 1]] + (tp[sa[i]] != tp[sa[i - 1]] || tp[sa[i] + p] != tp[sa[i - 1] + p]);
    maxN = rk[sa[N]];
}

bool cmp(int a , int b){
    return h[a] < h[b];
}

void init(){
    memset(valMax , -0x3f , sizeof(valMax));
    Max = -1ll * 0x3f3f3f3f * 0x3f3f3f3f;
    for(int i = 1 ; i <= N ; ++i)
        rk[tp[i] = i] = s[i] - 'a' + 1;
    sort(0);
    for(int i = 1 ; i <= N && maxN < N ; i <<= 1){
        int cnt = 0;
        for(int j = 1 ; j <= i ; ++j)
            tp[++cnt] = N - i + j;
        for(int j = 1 ; j <= N ; ++j)
            if(sa[j] > i)
                tp[++cnt] = sa[j] - i;
        sort(i);
    }
    for(int i = 1 ; i <= N ; ++i){
        if(rk[i] == 1)
            continue;
        int t = rk[i];
        h[t] = max(0 , h[rk[i - 1]] - 1);
        while(s[sa[t] + h[t]] == s[sa[t - 1] + h[t]])
            ++h[t];
        ind[t] = t;
    }
    sort(ind + 2 , ind + N + 1 , cmp);
    for(int i = 1 ; i <= N ; ++i){
        fa[i] = i;
        size[i] = 1;
        valMax[i][0] = val[i];
    }
}

inline void merge(int x , int y){
    fa[x] = y;
    int num[4] = {valMax[x][0] , valMax[x][1] , valMax[y][0] , valMax[y][1]};
    sort(num , num + 4);
    valMax[y][0] = num[3];
    valMax[y][1] = num[2]; 
    Max = max(Max , 1ll * valMax[y][0] * valMax[y][1]);
    num[0] = valMin[x][0];
    num[1] = valMin[x][1];
    num[2] = valMin[y][0];
    num[3] = valMin[y][1];
    sort(num , num + 4);
    valMin[y][0] = num[0];
    valMin[y][1] = num[1];
    if(1ll * valMin[y][0] * valMin[y][1])
        Max = max(Max , 1ll * valMin[y][0] * valMin[y][1]);
    cnt -= 1ll * size[x] * (size[x] - 1) / 2 + 1ll * size[y] * (size[y] - 1) / 2;
    size[y] += size[x];
    cnt += 1ll * size[y] * (size[y] - 1) / 2;
}

void work(){
    int p = N;
    for(int i = N - 1 ; i >= 0 ; --i){
        while(p > 1 && h[ind[p]] == i){
            merge(find(sa[ind[p]]) , find(sa[ind[p] - 1]));
            --p;
        }
        if(cnt){
            ans[i][0] = cnt;
            ans[i][1] = Max;
        }
    }
}

void output(){
    for(int i = 0 ; i <= N - 1 ; ++i)
        cout << ans[i][0] << ' ' << ans[i][1] << '\n';
}

int main(){
    input();
    init();
    work();
    output();
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Itst/p/10211549.html