SPOJ LCS Longest Common Substring(后缀自动机)题解

题意:

求两个串的最长公共子串（Longest Common Substring，\(LCS\)）的长度。

思路:

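对第一个串建立后缀自动机，然后让第二个串在自动机上逐字符匹配。若当前节点没有该字符的转移（失配），就沿后缀链接跳到父节点，并把当前匹配长度更新为父节点的 \(maxlen\)。如果一直跳到根之外，说明无法继续匹配，就回到根并把匹配长度置 0。否则沿转移前进，匹配长度加 1。每处理一个字符，都用当前匹配长度更新答案。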

代码:

#include<set>
#include<map>
#include<cmath>
#include<queue>
#include<bitset>
#include<string>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
typedef long long ll;            // shorthand for long long
typedef unsigned long long ull;  // shorthand for unsigned long long (not referenced by this program)
const int maxn = 500010;         // 500000 + 10: sizes the input buffer and the SAM node pool (2 * maxn)
const ll mod = 998244353;        // not referenced by this program

/*
 * Suffix automaton (SAM).
 *
 * State 1 is the root. Index 0 acts as the "null" state: a transition value
 * of 0 means "no transition", and walking suffix links ends at 0.
 * The pool holds 2 * maxn states and states are numbered from 1.
 */
struct SAM{
    enum { SIGMA = 27 };    // transition slots per state: symbols 0 .. SIGMA-1

    struct Node{
        int next[SIGMA];    // next[c]: target state on symbol c (0 = no transition)
        int fa, maxlen;     // fa: suffix link; maxlen: length of the longest substring in this state
        // Clear transitions, suffix link and length.
        void init(){
            memset(next, 0, sizeof(next));
            fa = maxlen = 0;
        }
    }node[maxn << 1];
    int sz, last;           // sz: highest allocated state index; last: state reached by the whole inserted text

    // Reset to the automaton of the empty string (root only).
    void init(){
        sz = last = 1;
        node[sz].init();
    }

    // Append symbol k to the text.
    // Precondition: 0 <= k < SIGMA (not checked here).
    void insert(int k){
        int p = last, np = last = ++sz;
        node[np].init();
        node[np].maxlen = node[p].maxlen + 1;
        // Give every state on last's suffix-link path that lacks a k-transition one to np.
        for(; p && !node[p].next[k]; p = node[p].fa)
            node[p].next[k] = np;
        if(p == 0) {
            // No state on the path had a k-transition: link np to the root.
            node[np].fa = 1;
        }
        else{
            int t = node[p].next[k];
            if(node[t].maxlen == node[p].maxlen + 1){
                // t's longest string is exactly p's longest plus k: link directly.
                node[np].fa = t;
            }
            else{
                // Clone t as nt with maxlen = p.maxlen + 1; both t and np link to nt.
                int nt = ++sz;
                node[nt] = node[t];
                node[nt].maxlen = node[p].maxlen + 1;
                node[np].fa = node[t].fa = nt;
                // Redirect k-transitions that pointed at t along p's suffix path to nt.
                for(; p && node[p].next[k] == t; p = node[p].fa)
                    node[p].next[k] = nt;
            }
        }
    }

    // Return the length of the longest substring of s that also occurs in the
    // inserted text (0 if none). s is read up to its NUL terminator.
    // A character whose symbol c = ch - 'a' lies outside [0, SIGMA) can never
    // match, so it resets the current match instead of indexing out of bounds.
    int query(const char *s) const{
        int best = 0;   // longest match seen so far
        int cur = 0;    // length of the match ending at the current character
        int pos = 1;    // current state, starting at the root
        for(int i = 0; s[i]; i++){
            int c = s[i] - 'a';
            if(c < 0 || c >= SIGMA){
                pos = 1;
                cur = 0;
                continue;
            }
            // Mismatch: shorten the match by following suffix links.
            while(pos && node[pos].next[c] == 0){
                pos = node[pos].fa;
                cur = node[pos].maxlen;   // if pos became 0, this value is overwritten below
            }
            if(pos == 0){
                // No suffix of the current match can be extended by c: restart at the root.
                pos = 1;
                cur = 0;
            }
            else{
                pos = node[pos].next[c];
                cur++;
            }
            if(cur > best) best = cur;
        }
        return best;
    }

    // Print (with a trailing newline) the length of the longest common
    // substring of s and the inserted text.
    void solve(char *s){
        printf("%d\n", query(s));
    }
}sam;
char s[maxn];
int main(){
    sam.init();
    scanf("%s", s);
    int len = strlen(s);
    for(int i = 0; i < len; i++) sam.insert(s[i] - 'a');
    scanf("%s", s);
    sam.solve(s);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/KirinSB/p/11667221.html