codeforces 1037H - Security (后缀自动机 + 线段树合并)

题解:首先分析,要大于给出的模式串并且尽可能小,那么一定是优先找和给出的模式串公共前缀尽可能长的字串,假设模式串 \(t\) 的长度为 \(tlen\)

\(t[tlen + 1] = a - 1\) ,那么思路的流程大概如下

1、首先后缀自动机上寻找匹配 \(t\) 的最长字串, 假设长度为 \(x\) 去掉 ,看 \(2\)

2、在模式串 \(x + 1\) 上的字符 \(+1\) ,比如 将 \(a\) 变成 \(b\) ,看 \(3\)

3、判断新的模式串是否存在文本串的字串匹配长度为 \(x +1\) , 若存在,看 \(4\)

4、判断是否存在对应 \(endpoint\) 的字串,若存在, 则输出当前模式串,若不存在,看 \(5\)

5、若新的字符已经是 \(z\) ,看 \(6\),否则将新 \(x + 1\) 的字符 +1, 比如 将 \(a\) 变成 \(b\) 后,看 \(3\)

6、\(x -= 1\) , 若 \(x == -1\) ,输出 \(-1\), 否则看 \(2\)

某个节点是否存在某个 \(endpoint\) 我们可以用线段树合并维护,具体看代码。

#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <cmath>
using namespace std;
typedef long long LL;
const int maxn = 2e5 + 50;
const LL mod = 1e9 + 7;
double eps = 1e-6;
 
struct state
{
    int len, link;
    int nex[30];
} st[maxn];
 
int sz, last;;
 
void sam_init(){
    st[0].len = 0;
    st[0].link = -1;
    sz = 1;
    last = 0;
}
 
char s[maxn], t[maxn];
int n;
int ed[maxn];
void sam_extend(int x){
    int cur = sz++;
    st[cur].len = st[last].len + 1; 
    int p = last;
    while(p != -1 && !st[p].nex[x]){
        st[p].nex[x] = cur;
        p = st[p].link;
    }
 
    if(p == -1){
        st[cur].link = 0;
    } else {
        int q = st[p].nex[x];
        if(st[p].len + 1 == st[q].len){
            st[cur].link = q;
        } else {
            int clone = sz++;
            st[clone].len = st[p].len + 1;
            for(int i = 0; i < 26; i++){
                st[clone].nex[i] = st[q].nex[i];
            }
            st[clone].link = st[q].link;
            while(p != -1 && st[p].nex[x] == q){
                st[p].nex[x] = clone;
                p = st[p].link;
            }
            st[q].link = st[cur].link = clone;
        }
    }
    last = cur;
}
 
struct qnode
{
    int ls, rs, val;
} tree[maxn * 30];
int tot, root[maxn];
void insert(int le, int ri, int pos, int &rt){
    if(!rt) rt = ++tot;
    tree[rt].val = 1;
    if(le == ri) return ;
    int mid = (le + ri) >> 1;
    if(pos <= mid) insert(le, mid, pos, tree[rt].ls);
    else insert(mid + 1, ri, pos, tree[rt].rs);
}
 
int merge(int u, int v){ // 线段树合并
    if(!u || !v) return u | v;
    int p = ++tot;  //记住要开新节点存合并的线段树,因为后面的查询可能要用到所有节点的线段树
    tree[p].val = tree[u].val | tree[v].val; //由于只用判断该区间存不存在值,所以这样更新就好了
    tree[p].ls = merge(tree[u].ls, tree[v].ls);
    tree[p].rs = merge(tree[u].rs, tree[v].rs);
    return p;
}
 
int Query(int le, int ri, int L, int R, int rt){
    if(L <= le && ri <= R) return tree[rt].val;
    if(!tree[rt].val) return 0;
    int mid = (le + ri) >> 1;
    if(L <= mid && Query(le, mid, L, R, tree[rt].ls)) return 1;
    if(R > mid && Query(mid + 1, ri, L, R, tree[rt].rs)) return 1;
    return 0;    
}
int tax[maxn], id[maxn];
void pre(){ // 根据每个节点的最大长度进行排序,然后从 link 树的下面往上进行线段树合并,因为最大长度越长的节点一定在越靠近link树的根节点
    for(int i = 1; i < sz; i++) tax[st[i].len]++;
    for(int i = 1; i <= n; i++) tax[i] += tax[i - 1];
    for(int i = 1; i < sz; i++) id[tax[st[i].len]--] = i;
    for(int i = 1; i <= n; i++) insert(1, n, i, root[ed[i]]);
    for(int i = sz - 1; i > 1; i--){
        root[st[id[i]].link] = merge(root[st[id[i]].link], root[id[i]]); // 将每个节点合并到它的父节点上
    }
}
 
int pos[maxn];
void solve(){
    int le, ri;
    scanf("%d%d%s", &le, &ri, t + 1);
    int len = strlen(t + 1);
    t[len + 1] = 'a' - 1; // 这句看不懂的先往下看,很容易理解
    int p = 0;
    int mal = 0;
    for(int i = 1; i <= len; i++){
        int x = t[i] - 'a';
        if(st[p].nex[x]){
            p = st[p].nex[x];
            pos[i] = p; // 记录每个位置的 p 
            mal = i;
        } else {
            break;
        }
    }
 
    for(int i = mal; i >= 0; i--){
        int x = t[i + 1] - 'a' + 1;
        while(x < 26){
            p = st[pos[i]].nex[x];
            if(p){
                int res = Query(1, n, le + i, ri, root[p]);
                if(res) {
                    for(int j = 1; j <= i; j++){
                        printf("%c", t[j]);
                    }
                    printf("%c\n", x + 'a');
                    return ;
                }
            }
            x++;
        }
    }
    printf("-1\n");
}
int main()
{
    sam_init();
    scanf("%s", s + 1);
    n = strlen(s + 1);
    for(int i = 1; i <= n; i++){
        int x = s[i] - 'a';
        sam_extend(x);
        ed[i] = last; // 记录文本串每个前缀对应的节点,因为所有节点的endpoint,都是从这些点转移过来的
    }
    pre();
    int q;
    scanf("%d", &q);
    while(q--){
        solve();
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/PCCCCC/p/13389540.html