兔子的字符串 - 后缀数组 - 二分 - 贪心

题目大意:给你个字符串s,将其划分成不超过k段使得每段的字典序最大子串的最大的一个t最小,问t是啥。1e5。
题解:考虑二分,朴素的二分是确定每一位,T飞。发现答案是s的子串,而s只有O(n^2)个子串,二分这个即可。二分完后对每个后缀分类讨论一下可以转化为给定若干区间选尽量少的点使得每个区间有至少一个点,经典问题:先不管那些覆盖别的区间的区间然后左端点排序然后贪心的向右选即可。精细的实现一下可以做到一个log。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define N 100010
#define LOG 20
#define lint long long
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
int wa[N],wb[N],cnt[N],v[N],sa[N],rk[N],h[N];
char s[N];int a[N],R[N],Log[N],mnv[N][LOG];
inline int cmp(int *a,int x,int y,int l)
{
    return a[x]==a[y]&&a[x+l]==a[y+l];
}
inline int getSA(int *a,int n,int m)
{
    int *x=wa,*y=wb,i,j,p;
    memset(wa,0,sizeof(wa));
    memset(wb,0,sizeof(wb));
    for(i=1;i<=m;i++) cnt[i]=0;
    for(i=1;i<=n;i++) cnt[x[i]=a[i]]++;
    for(i=1;i<=m;i++) cnt[i]+=cnt[i-1];
    for(i=n;i;i--) sa[cnt[x[i]]--]=i;
    for(p=0,j=1;p<=n;j<<=1,m=p)
    {
        for(p=0,i=n-j+1;i<=n;i++) y[++p]=i;
        for(i=1;i<=n;i++)
            if(sa[i]>j) y[++p]=sa[i]-j;
        for(i=1;i<=m;i++) cnt[i]=0;
        for(i=1;i<=n;i++) cnt[v[i]=x[y[i]]]++;
        for(i=1;i<=m;i++) cnt[i]+=cnt[i-1];
        for(i=n;i;i--) sa[cnt[v[i]]--]=y[i];
        for(swap(x,y),p=i=1;i<=n;i++)
            x[sa[i]]=(cmp(y,sa[i],sa[i-1],j)?p-1:(p++));
    }
    for(i=1;i<=n;i++) rk[sa[i]]=i;
    for(i=1,p=0;i<=n;h[rk[i++]]=p)
        for((p?p--:0),j=sa[rk[i]-1];a[i+p]==a[j+p];p++);
    return 0;
}
inline int getkths(int n,lint k,int &x,int &y)
{
    for(int i=1;i<=n;i++)
        if(k<=n-sa[i]+1-h[i]) return x=sa[i],y=h[i]+k,0;
        else k-=n-sa[i]+1-h[i];
    return x=y=0;
}
inline int LCP(int x,int y)
{
    x=rk[x],y=rk[y];if(x>y) swap(x,y);
    x++;int k=Log[y-x+1];
    return min(mnv[x][k],mnv[y-(1<<k)+1][k]);
}
inline int calc(lint k,int n)
{
    int p,l;getkths(n,k,p,l);
    for(int i=1;i<=n;i++) if(a[i]>a[p]) return n+1;
    for(int i=1;i<=n;i++)
        if(i==p) R[i]=i+l-1;
        else{
            int x=min(LCP(p,i),l);
            if(x==l||a[p+x]<a[i+x]) R[i]=i+x-1;
            else R[i]=n;
        }
    for(int i=n-1;i>=1;i--) R[i]=min(R[i],R[i+1]);
    int x=0,ans=0;
    for(int i=1;i<=n;i++) if(i>x) x=R[i],ans++;
    return ans;
}
int main()
{
    int k,n;scanf("%d%s",&k,s+1),n=(int)strlen(s+1);
    for(int i=1;i<=n;i++) a[i]=s[i]-'a'+1;getSA(a,n,27);
    for(int i=2;i<=n;i++) mnv[i][0]=h[i];
    for(int j=1;(1<<j)<=n;j++)
        for(int i=2;i+(1<<j)-1<=n;i++)
            mnv[i][j]=min(mnv[i][j-1],mnv[i+(1<<(j-1))][j-1]);
    for(int i=2;i<=n;i++) Log[i]=Log[i>>1]+1;
    lint L=1,R=0,mid=(L+R)>>1;
    for(int i=1;i<=n;i++) R+=n-sa[i]+1-h[i];
    while(L<=R)
    {
        if(calc(mid,n)<=k) R=mid-1;
        else L=mid+1;mid=(L+R)>>1;
    }
    int x,y;getkths(n,L,x,y);
    for(int i=x;i<x+y;i++) printf("%c",s[i]);
    return !printf("\n");
}

猜你喜欢

转载自blog.csdn.net/Mys_C_K/article/details/82224423