[loj#3049] [十二省联考 2019] 字符串问题

题意简述

给定字符串 \(S\)
给定 \(n_a\)\(A\) 类串和 \(n_b\)\(B\) 类串,每个串都是 \(S\) 的子串。
给定 \(m\) 组支配关系 \((x,y)\) ,表示第 \(x\)\(A\) 串支配第 \(y\)\(B\) 串。

求一个长度最长的 \(T\) 串,存在一个分割 \(T=t_1+t_2+t_3+...t_k\) 满足:
1.分割中的每个串 \(t_i\) 均为 A 类串。
2.\(t_i\) 代表的 \(A\) 串所支配的某一个 \(B\) 串是 \(t_{i+1}\) 的前缀。

\(T\) 无限长则输出 \(-1\)
\(n_a,n_b,|S|,m \leq 2\times 10^5\)


想法

比较显然的想法是,\(A\) 类串向它可支配的 \(B\) 类串连边,\(B\) 类串向以它为前缀的 \(A\) 类串连边,在这个图上跑 \(dp\) 就行了。
联想到后缀自动机中的 后缀链接形成的 \(pa\) 树(后缀树),树从上到下相当于在串前面添加字符,也就是说有父子关系的点是有相同后缀的。
而此题中,要求的是前缀相同,于是把串倒过来加到后缀自动机里就好啦~

通过在后缀树上倍增找到每个 \(A\) 类串与 \(B\) 类串代表的点,之后连边。
注意可能有些长度不同的串对应于 \(SAM\) 上同一个节点,于是在每个节点开个 \(vector\) 装此点代表的 \(A\)\(B\) 串,之后按串长度排序。
每个点与它的祖先中(包括在后缀树上的祖先点与该点中长度比其短的点)的离它最近的 \(B\) 串点连边。这个过程可通过 \(dfs\) 完成。
之后 \(A\) 串点向它可支配的 \(B\) 串点连边。
新图上跑 \(dp\) ,当图中有环时输出 \(-1\)


总结

把串反过来加到 \(SAM\) 中是一个处理 “相同前缀” 问题的套路。
利用好后缀树,在后缀树上倍增找点也是套路。


代码

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>

using namespace std;

int read(){
    int x=0;
    char ch=getchar();
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x;
}

const int N = 200005;
typedef long long ll;

int cnt,root,last,ch[N*2][26],len[N*2],pa[N*2];
int ed[N];
void ins(int c){
    int x=last,cur=++cnt;
    len[cur]=len[last]+1;
    for(;x && !ch[x][c];x=pa[x]) ch[x][c]=cur;
    if(!x) pa[cur]=root;
    else{
        int y=ch[x][c],ny;
        if(len[y]==len[x]+1) pa[cur]=y;
        else{
            ny=++cnt;
            len[ny]=len[x]+1;
            pa[ny]=pa[y]; pa[y]=pa[cur]=ny;
            for(int i=0;i<26;i++) ch[ny][i]=ch[y][i];
            for(;x && ch[x][c]==y;x=pa[x]) ch[x][c]=ny;
        }
    }
    last=cur;
}

struct node{
    int v;
    node *nxt;
}pool[N*6],*h[N*2],*son[N*2];
int cnt1;
void addedge(int u,int v){
    node *p=&pool[++cnt1];
    p->v=v;p->nxt=h[u];h[u]=p;
}
void addedge1(int u,int v){
    node *p=&pool[++cnt1];
    p->v=v;p->nxt=son[u];son[u]=p;
}

int f[N*2][20],in[N*2],que[N*2],hd,tl;
void getf(){
    hd=tl=0;
    for(int i=1;i<=cnt;i++) in[pa[i]]++;
    for(int i=1;i<=cnt;i++) if(!in[i]) que[tl++]=i;
    while(hd<tl){
        int u=que[hd++];
        in[pa[u]]--;
        if(in[pa[u]]==0) que[tl++]=pa[u];
    }
    for(int i=tl-1;i>=0;i--){
        int u=que[i];
        f[u][0]=pa[u];
        if(pa[u]) addedge1(pa[u],u);
        for(int j=1;j<20;j++) f[u][j]=f[f[u][j-1]][j-1];
    }
}
int find(int l,int r){
    int x=ed[l];
    for(int i=19;i>=0;i--) if(len[f[x][i]]>=r-l+1) x=f[x][i];
    return x;
}

char s[N];
int n,m,na,nb,w;
int A[N],B[N];

struct data{
    int id,len;
    data() { id=len=0; }
    data(int x,int y) { id=x; len=y; }
    bool operator < (const data &b) const{ return len<b.len || (len==b.len && id>b.id); }
};
vector<data> vv[N*2];

void build(int u,int pre){
    sort(vv[u].begin(),vv[u].end());
    for(int i=0;i<vv[u].size();i++){
        if(pre) addedge(pre,vv[u][i].id);
        if(vv[u][i].id>na) pre=vv[u][i].id;
    }
    for(node *p=son[u];p;p=p->nxt) build(p->v,pre);
}

int flag;
ll mx[N*2];
int vis[N*2],val[N*2];
ll dp(int u){
    if(mx[u]!=-1) return mx[u];
    int v;
    vis[u]=1;
    mx[u]=0;
    for(node *p=h[u];p;p=p->nxt)
        if(vis[v=p->v]==1) flag=0;
        else mx[u]=max(mx[u],dp(v));
    mx[u]+=val[u];
    vis[u]=2;
    return mx[u];
}

int main()
{
    int T=read(),u,l,r;
    while(T--){
        scanf("%s",s+1);
        n=strlen(s+1);
        root=last=++cnt;
        for(int i=n;i>=1;i--) ins(s[i]-'a');
        u=root;
        for(int i=n;i>=1;i--) {
            u=ch[u][s[i]-'a'];
            ed[i]=u;
        }
        getf();
        
        na=read();
        for(int i=1;i<=na;i++){
            l=read(); r=read();
            A[i]=find(l,r);
            vv[A[i]].push_back(data(i,r-l+1));
            val[i]=r-l+1;
        }
        nb=read();
        for(int i=1;i<=nb;i++){
            l=read(); r=read();
            B[i]=find(l,r);
            vv[B[i]].push_back(data(i+na,r-l+1));
        }
        build(root,0);
        w=na+nb;
        
        m=read();
        for(int i=0;i<m;i++){
            l=read(); r=read();
            addedge(l,r+na);
        }
        
        flag=1;
        for(int i=1;i<=w;i++) mx[i]=-1;
        ll ans=0;
        for(int i=1;i<=w;i++) ans=max(ans,dp(i));
        if(!flag) printf("-1\n");
        else printf("%lld\n",ans);
        
        //clear
        for(int i=1;i<=cnt;i++){
            pa[i]=len[i]=0;
            for(int j=0;j<26;j++) ch[i][j]=0;
            in[i]=0; son[i]=NULL;
            vv[i].clear();
        }
        cnt=0; cnt1=0;
        for(int i=1;i<=w;i++){
            h[i]=NULL;
            val[i]=vis[i]=0;
        }
    }
    
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/lindalee/p/12458547.html
今日推荐