【题解】CF 1073 G. Another LCP

https://codeforces.com/problemset/problem/1073/G
题意:给定一长度为n的字符串s,并有q次询问,每次询问给定两个序列\(\{a_i\}\)(记长度为k)和\(\{b_j\}\)(记长度为l),输出\(\Sigma_{i=1}^{k} \Sigma_{j=1}^{l} LCP(s[a_i,n],s[b_i,n])\)

数据范围:\(1\leq n,\Sigma k, \Sigma l \leq 2e5\)

思路

求LCP最方便的应该还是后缀数组,处理出height数组后,则\(LCP(i,j)=height[i+1\sim j]\)中的最小值(i,j表示排序第i位和排序第j位的后缀。用ST表预处理区间最小值就可以做到O(1)询问。

但数据范围是\(1\leq \Sigma k, \Sigma l \leq 2e5\),显然如果暴力询问,那么是\(O(k^2)\)的,自然T飞。

按照这个思路,这道题实际上可以转化如下:

给定\(\{a_i\}\)\(\{b_j\}\),求由以上序列构成点对组成的区间的最小值的和。

所以……这tm明明就是一道数据结构题,和字符串有什么关系!?

对于这种计数问题,如果不能暴力枚举,那必然是可以合并枚举过程,将相同或者相近的部分一次性加起来,减少枚举次数。

那按照这个思路,两个什么样的区间会有相同的最小值?实际上只要包含了同样一个最小数即可。(意会一下,这句话不怎么严谨

所以,如果我们枚举整个大区间的最小值,那么任何包含这个位置的区间的最小值都是这个值。

那可否继续枚举第二小的值?当然可以!只要保证选取区间不包含刚刚那个最小数的位置,那么这些区间的最小值都是第二小的值。

以此类推,从小到大枚举,每次统计不包含之前枚举过的区间的个数,再乘当前值,便是这一部分的贡献。最后把每部分贡献加起来就是最终答案。

那么每次怎么统计区间个数?

设区间长度长度为\(l\),建立一个set,存放的是所有已经枚举过的位置。初始放入\(0\)\(l+1\)。之后最小值、第二小、第三小……的位置i,在set中找到\(i\)左边的第一个数\(mn\)\(i\)右边的第一个数\(mx\),统计出\(\{a_i\}\)序列中\(mn< a_i\leq i\)的个数和\(\{b_i\}\)序列中\(i<b_i< mx\)的个数(这两个个数,把出现位置标记为1,求前缀和即可O(1)询问),两个相乘,便是:

  • 包含位置i
  • 不包含之前已经枚举过的位置
  • 左端点来自\(a_i\),右端点来自\(b_j\)

的区间的个数(不包含只含位置\(i\)的区间)。如果要枚举左端点来自\(b_i\),右端点来自\(a_j\)的区间,那么也可以类似分析得到。至于只含位置i的区间,可以单独计算。

这样我们就解决了转化之后的问题,时间复杂度是\(O(l\log l)\)

但这样还是不可以,因为这里的\(l\)是区间长度,但题目中说的是给的端点数量\(\leq 2e5\),如果每次都只给两个点1和2e5,那么最坏复杂度反而成了\(1e5\times 2e5 \log 2e5\)

所以我们要做的就是……离散化,把不用的区间都合并成点,简单思路如下:(说不清楚,意会一下要点

  • 将本轮询问的所有\(\{a_i\}\)\(\{b_i\}\)都放到同一个序列里,从小到大排序。
  • 对于两个不相等的原始坐标i和j(i<j),如果i<j-1,那么对i,[i+1,j-1],j分别建新点。比如说3和200,就变成3一个点,4-199区间里的最小值作为一个点,200一个点。
  • 如果i=j-1,那就只建立i和j两个点。比如说3和4,显然3和4中间也没有数,不需要新加点。

这样两个原始端点之间最多只有一个新加点,区间长度最大也只有\(2*2e5\),再套用上面的办法,复杂度就对了。

程序中使用了set和排序等操作,渐进复杂度\(O(n\log n+(\Sigma k+\Sigma l)\log(\Sigma k+\Sigma l))\),但实际上用了很多次排序,所以常数大的飞起。

总结:

  • 写题太慢,这个题写了一天,debug了一年。
  • 细节太多了。最麻烦的地方有两个,一个是把LCP转化成上面区间询问最小值的问题,由于计算方法不一样,必须分成三部分。对LCP(i,i),预先统计;对LCP(i,j)(i<j),要把所有的i都+1,然后跑一遍,对于(i>j)的情况,得把j+1再跑一遍……另外一个是离散化,太难写了。最后代码快200行了……

AC代码:841ms,43836KB。中间有些地方(尤其是离散化部分)为了方便理解并让自己不要弄混,因此写的并不是最优化的版本,很多重复的代码。

#include<bits/stdc++.h>
#define LL long long 
using namespace std;
const int M=2e5+50,P=1e9+7;
char s[M];
int rk[M],sa[M],tp[M],tax[M],height[M],len,m=150,a[M],b[M];//rk(第一关键字)第i位的排名,sa是排名为i的位置,tp是第二关键字辅助用的,tax是桶数组;
void Rsort(){//用k*2排序的第二关键字和k排序的结果(rk,即k*2排序的第一关键字)得到k*2排序结果sa
    for(int i=1;i<=m;i++) 
        tax[i]=0;
    for(int i=1;i<=len;i++)//把第一关键字放到桶里
        tax[rk[tp[i]]]++;
    for(int i=1;i<=m;i++) 
        tax[i]+=tax[i-1];
    for(int i=len;i>=1;i--)//从第二关键字小的开始往sa里丢,先进入的在后面
        sa[tax[rk[tp[i]]]--]=tp[i];
}
void Suffix(){
    for(int i=1;i<=len;i++)
        rk[i]=s[i],tp[i]=i;//得到用2^0排序的rk,tp是随便搞的
    Rsort();//得到用2^0排序的sa
    for(int k=1;k<=len;k<<=1){
        //将用k排序的结果变成用k*2排序的第二关键字
        int num=0;
        for(int i=len-k+1;i<=len;i++)
            tp[++num]=i;
        for(int i=1;i<=len;i++)
            if(sa[i]>k)
                tp[++num]=sa[i]-k;
        Rsort(),swap(rk,tp);//用tp存下k排序的结果rk,下面用k*2排序的sa得到k*2排序的rk,用k排序的rk是为了找出相同部分
        rk[sa[1]]=num=1;
        for(int i=2;i<=len;i++)
            rk[sa[i]]=(tp[sa[i]]==tp[sa[i-1]]&&tp[sa[i]+k]==tp[sa[i-1]+k])?num:++num;//虽然sa下标不一样,但rk必须将排位相同的标记出来
        if(num==len)
            break;
        m=num;
    }
}
void getheight(){
    int k=0;
    for(int i=1,j;i<=len;i++){
        if(rk[i]==1)
            k=0;
        else{
            if(k)
                --k;
            j=sa[rk[i]-1];
            while(i+k<=len&&j+k<=len&&s[i+k]==s[j+k])
                ++k;
        }
        height[rk[i]]=k;
    }
}
int mi2[35],logg[M],n,f[M][35];
void initst(int n){
    mi2[0]=1,logg[0]=-1;
    for(int i=1;i<=30;++i)
        mi2[i]=mi2[i-1]*2;
    for(int i=1;i<=n;++i)
        logg[i]=logg[i/2]+1;
    for(int i=1;i<=n;++i)
        f[i][0]=height[i];
    for(int i=1;i<=logg[n];++i)
        for(int j=1;j<=n+1-mi2[i];++j)
            f[j][i]=min(f[j][i-1],f[j+mi2[i-1]][i-1]);
}
int query(int l,int r){
    int lg=logg[r-l+1];
    return min(f[l][lg],f[r-mi2[lg]+1][lg]);
}
int id[M],numb[M],numa[M];
vector<int> nh;
bool cmp(int x,int y){
    return nh[x]<nh[y];
};
LL sum(int *a,int la,int *b,int lb){//nh放的是各坐标对应的height最小值 
    nh.clear();
    nh.push_back(0);
    int pa=1,pb=1,p=0,ls;
    for(int i=1;i<=la;++i)
        a[i]+=1;//p记录当前最大坐标,ls记录上一次放进去的原坐标,用于判断是否要新建中间的区间 
    if (a[pa]<=b[pb])
        numb[++p]=0,numa[p]=1,nh.push_back(height[a[pa]]),ls=a[pa],++pa;
    else
        numb[++p]=1,numa[p]=0,nh.push_back(height[b[pb]]),ls=b[pb],++pb;
    while (pa<=la&&pb<=lb)
        if (a[pa]<=b[pb]){
            if (ls==a[pa]){
                numa[p]=1,++pa;
                continue;
            }
            else if (ls==a[pa]-1)//只需要插入a[pa] 
                numb[++p]=0,numa[p]=1,nh.push_back(height[a[pa]]),ls=a[pa],++pa;
            else//否则先插入上一个区间,再插入a[pa]
                nh.push_back(query(ls+1,a[pa]-1)),numb[++p]=0,numa[p]=0,
                numb[++p]=0,numa[p]=1,nh.push_back(height[a[pa]]),ls=a[pa],++pa;
        }
        else{
            if (ls==b[pb]){
                numb[p]=1,++pb;
                continue;
            }
            else if (ls==b[pb]-1)//只需要插入b[pb] 
                numb[++p]=1,numa[p]=0,nh.push_back(height[b[pb]]),ls=b[pb],++pb;
            else//否则先插入上一个区间,再插入b[pb]
                nh.push_back(query(ls+1,b[pb]-1)),numb[++p]=0,numa[p]=0,
                numb[++p]=1,numa[p]=0,nh.push_back(height[b[pb]]),ls=b[pb],++pb;
        }
    while (pa<=la)
        if (ls==a[pa]){
            numa[p]=1,++pa;
            continue;
        }
        else if (ls==a[pa]-1)//只需要插入a[pa] 
            numb[++p]=0,numa[p]=1,nh.push_back(height[a[pa]]),ls=a[pa],++pa;
        else//否则先插入上一个区间,再插入a[pa]
            nh.push_back(query(ls+1,a[pa]-1)),numb[++p]=0,numa[p]=0,
            numb[++p]=0,numa[p]=1,nh.push_back(height[a[pa]]),ls=a[pa],++pa;
    while (pb<=lb)
        if (ls==b[pb]){
            numb[p]=1,++pb;
            continue;
        }
        else if (ls==b[pb]-1)//只需要插入b[pb] 
            numb[++p]=1,numa[p]=0,nh.push_back(height[b[pb]]),ls=b[pb],++pb;
        else//否则先插入上一个区间,再插入b[pb]
            nh.push_back(query(ls+1,b[pb]-1)),numb[++p]=0,numa[p]=0,
            numb[++p]=1,numa[p]=0,nh.push_back(height[b[pb]]),ls=b[pb],++pb;
    for (int i=1;i<=p;++i)
        id[i]=i,numa[i]+=numa[i-1],numb[i]+=numb[i-1];
    sort(id+1,id+p+1,cmp);
    set<int> st;
    st.insert(0),st.insert(p+1);
    LL ans=0;
    for (int i=1;i<=p;++i){
        int &t=id[i];
        set<int>::iterator mn1=st.lower_bound(t);
        --mn1;
        int mn=*mn1,mx=*(st.upper_bound(t));
        ans+=(numa[t]-numa[mn])*1LL*(numb[mx-1]-numb[t-1])*nh[t];
        st.insert(t);
    }
    for(int i=1;i<=la;++i)
        a[i]-=1;
    return ans;
}
int main(){
    int n,q;
    scanf("%d%d",&len,&q); 
    scanf("%s",s+1);
    Suffix(),getheight();
    initst(len);
    for (int z=1;z<=q;++z){
        int k,l,c;
        scanf("%d%d",&k,&l);
        for (int i=1;i<=k;++i)
            scanf("%d",&a[i]);
        for (int i=1;i<=l;++i)
            scanf("%d",&b[i]);
        LL ans=0;
        int p1=1,p2=1;
        while (p1<=k&&p2<=l)//统计相等的 
            if (a[p1]==b[p2])
                ans+=len-a[p1]+1,++p1,++p2;
            else if (a[p1]>b[p2])
                ++p2;
            else
                ++p1;
        for (int i=1;i<=k;++i)
            a[i]=rk[a[i]];
        for (int i=1;i<=l;++i)
            b[i]=rk[b[i]];
        sort(a+1,a+k+1);
        sort(b+1,b+l+1);
        ans=ans+sum(a,k,b,l)+sum(b,l,a,k);
        printf("%lld\n",ans);
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/diorvh/p/11821333.html