不同子串个数(后缀数组)

题目

传送门

题解

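后缀数组的经典应用,关键在于这一句:ans+=(ll)(n-sa[i])-height[i];。每个子串都是某个后缀的前缀。按字典序排第 i 的后缀 sa[i] 共有 n-sa[i] 个前缀,其中恰有 height[i] 个(即它与排在前一位的后缀的最长公共前缀长度)已经在前一个后缀处计过。因此把每个后缀新增的 n-sa[i]-height[i] 个前缀累加起来,就是不同子串的个数。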

代码

#include <cstdio>
#include <cstring>
#include <algorithm>
#include <iostream>
using namespace std;
#define ll long long

// Capacity of every global array; the input string must be shorter than this
// (one slot is needed for the terminating NUL written by scanf("%s")).
const int maxn = 1000000;

char ch[maxn];          // input string, NUL-terminated
int n;                  // length of ch
int m = 200;            // number of counting-sort buckets used by getSA
int x[maxn], y[maxn];   // rank / second-key scratch arrays for prefix doubling
int sa[maxn];           // suffix array: sa[r] = start of the r-th smallest suffix
int c[maxn];            // counting-sort bucket counters
int rnk[maxn];          // inverse of sa: rnk[sa[r]] = r
int height[maxn];       // height[r] = LCP of suffixes sa[r-1] and sa[r]
ll ans;                 // accumulated count of distinct substrings

// Dumps the first n entries of sa, x, y and c to stdout, one labelled row per
// array under a header row of column indices, followed by a blank line.
void debug()
{
    // Header row: the index of each column.
    printf("下标 ");
    for (int col = 0; col < n; ++col) printf("%d  ", col);
    printf("\n");

    // One labelled row per array, values aligned with the header columns.
    const char *const labels[] = {"sa:  ", "x:   ", "y:   ", "c:   "};
    const int *const rows[] = {sa, x, y, c};
    for (int r = 0; r < 4; ++r)
    {
        printf("%s", labels[r]);
        for (int col = 0; col < n; ++col) printf("%d  ", rows[r][col]);
        printf("\n");
    }

    // Blank line separating successive dumps.
    printf("\n");
}

// Builds the suffix array of ch[0..n) by prefix doubling with radix
// (counting) sort. On return sa[r] is the start index of the r-th smallest
// suffix, with characters compared as unsigned bytes.
// Reads globals ch, n; uses x, y, c as scratch; overwrites m; writes sa.
void getSA()
{
    // The first pass buckets suffixes by their first byte. Index with the
    // unsigned byte value: a plain (possibly signed) char would be negative
    // for bytes >= 0x80 and index c[] out of bounds. 256 buckets cover every
    // byte value. Resetting m here (rather than trusting the global
    // initializer) also keeps repeated calls correct.
    m = 256;
    for (int i = 0; i < m; i++) c[i] = 0;
    for (int i = 0; i < n; i++)
    {
        x[i] = static_cast<unsigned char>(ch[i]);
        c[x[i]]++;
    }
    for (int i = 1; i < m; i++) c[i] += c[i - 1];
    for (int i = n - 1; i >= 0; i--) sa[--c[x[i]]] = i;

    for (int k = 1; k <= n; k <<= 1)
    {
        // y lists suffixes ordered by their second key (rank of i + k):
        // suffixes with no second half (i >= n - k) come first, then the
        // rest in the current sa order.
        int p = 0;
        for (int i = n - k; i < n; i++) y[p++] = i;
        for (int i = 0; i < n; i++) if (sa[i] >= k) y[p++] = sa[i] - k;

        // Stable counting sort of y by the first key x[] yields the order by
        // the pair (x[i], x[i + k]).
        for (int i = 0; i < m; i++) c[i] = 0;
        for (int i = 0; i < n; i++) c[x[i]]++;
        for (int i = 1; i < m; i++) c[i] += c[i - 1];
        for (int i = n - 1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];

        // Move the old ranks into y. Only indices [0, n) are ever read or
        // written, so swapping that range is enough; swapping the whole
        // maxn-sized arrays would cost O(maxn) per round even for tiny n.
        std::swap_ranges(x, x + n, y);

        // Re-rank: adjacent suffixes in sa share a rank iff both halves match
        // (a missing second half compares as -1).
        p = 1;
        x[sa[0]] = 0;
        for (int i = 1; i < n; i++)
        {
            const int a = sa[i - 1];
            const int b = sa[i];
            const int secondA = a + k >= n ? -1 : y[a + k];
            const int secondB = b + k >= n ? -1 : y[b + k];
            x[b] = (y[a] == y[b] && secondA == secondB) ? p - 1 : p++;
        }

        // p is the number of distinct ranks and can never exceed n; once it
        // reaches n every suffix is uniquely ranked and sa is final.
        if (p >= n) break;
        m = p;
    }
}

// Kasai's algorithm. Fills rnk with the inverse of sa and sets
// height[r] = length of the longest common prefix of suffixes sa[r-1] and
// sa[r], with height[0] = 0. Reads globals ch, n, sa.
void getHeight()
{
    for (int r = 0; r < n; ++r)
        rnk[sa[r]] = r;

    height[0] = 0;
    // Carried across text positions: moving from suffix pos to pos+1 the LCP
    // with the predecessor drops by at most one, so it is reused minus one.
    int lcp = 0;
    for (int pos = 0; pos < n; ++pos)
    {
        const int r = rnk[pos];
        if (r == 0)
            continue;  // the smallest suffix has no predecessor

        lcp = lcp > 0 ? lcp - 1 : 0;
        const int prev = sa[r - 1];
        while (pos + lcp < n && prev + lcp < n && ch[pos + lcp] == ch[prev + lcp])
            ++lcp;
        height[r] = lcp;
    }
}

int main()
{
    scanf("%d",&n);
    scanf("%s",ch);
    n=strlen(ch);
    getSA();
    getHeight();
//  for (int i=0; i<n; i++) printf("%d ",sa[i]); printf("\n");
//  for (int i=0; i<n; i++) printf("%d ",rnk[i]); printf("\n");
//  for (int i=0; i<n; i++) printf("%d ",height[i]); printf("\n");
    for (int i=0; i<n; i++) ans+=(ll)(n-sa[i])-height[i];
    printf("%lld",ans);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/A_Comme_Amour/article/details/79811754