[NOI Online 2021 提高组] 积木小赛——后缀自动机+子序列自动机

同步于:https://www.luogu.com.cn/blog/OUYE2020/solution-p7469

题解

考虑字符串T中的本质不同的子串最多 n 2 n^2 n2 个,暴力枚举判断至少是 O ( n 3 ) O(n^3) O(n3) 的,用 h a s h hash hash 可以优化到 O ( n 2 ) O(n^2) O(n2) 。然而为了保证正确性,我们不用 h a s h hash hash (其实我做字符串题就只用过一次这玩意

那么为了去重,最好的方法是对 T T T 串建立一个后缀自动机,按边遍历整个后缀自动机就可以像走 t r i e trie trie 树一样得到每个子串的信息。由于后缀自动机上最多 2 n 2n 2n 个点,每个点会遍历不超过 n n n 次(其实比 n n n 小得多),所以遍历一遍均摊复杂度 O ( n 2 ) O(n^2) O(n2) ,在 n > 26 n>26 n>26 的情况下严格小于 O ( n 2 ) O(n^2) O(n2)

怎么拿 T T T 中的子串和 S S S 中的子序列匹配呢?我们贪心地记录 f i , c f_{i,c} fi,c 为第 i i i 个字符往后第一次出现字符 c c c 的位置,把 i i i f i , c f_{i,c} fi,c 之间连边,
最终可得到一颗最多 n ∗ 26 n*26 n26 条边的 t r i e trie trie 树。根据贪心思路,这棵树必定记录了 S S S 的所有子序列的信息,这颗 t r i e trie trie 树就叫 S S S子序列自动机

然后我们只需要同步遍历两个自动机,求出所有重合的节点数就是答案。这个复杂度为两个自动机遍历次数取较小值,所以总复杂度小于 O ( n 2 ) O(n^2) O(n2)

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#include<queue>
#include<stack>
#define ll long long
#define MAXN 3005
#define uns unsigned
#define INF 0x3f3f3f3f
using namespace std;
inline ll read(){
    
    
	ll x=0;bool f=1;char s=getchar();
	while((s<'0'||s>'9')&&s>0){
    
    if(s=='-')f^=1;s=getchar();}
	while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+s-'0',s=getchar();
	return f?x:-x;
}
struct SAM{
    
    
	int ch[26],len,fa;
	SAM(){
    
    memset(ch,0,sizeof(ch)),len=fa=0;}
}sam[MAXN<<1];
int las=1,tot=1;
inline void samadd(int c){
    
    
	int p=las,np=las=++tot;sam[np].len=sam[p].len+1;
	for(;p&&sam[p].ch[c]==0;p=sam[p].fa)sam[p].ch[c]=np;
	if(!p)sam[np].fa=1;
	else{
    
    int q=sam[p].ch[c];
		if(sam[q].len==sam[p].len+1)sam[np].fa=q;
		else{
    
    
			int nq=++tot;sam[nq]=sam[q],sam[nq].len=sam[p].len+1,sam[q].fa=sam[np].fa=nq;
			for(;p&&sam[p].ch[c]==q;p=sam[p].fa)sam[p].ch[c]=nq;
		}
	}
}
int n,ans;
char a[MAXN],b[MAXN];
int tr[MAXN][26];
inline void dfs(int x,int y){
    
    
	if(!x||!y)return;
	if(x>1)ans++;
	for(int i=0;i<26;i++)
		dfs(sam[x].ch[i],tr[y][i]);
}
signed main()
{
    
    
//	freopen("block.in","r",stdin);
//	freopen("block.out","w",stdout);
	n=read();
	scanf("%s\n%s",a+2,b+1);
	for(int i=n;i>0;i--){
    
    
		for(int j=0;j<26;j++)tr[i][j]=tr[i+1][j];
		tr[i][a[i+1]-'a']=i+1;
	}
	for(int i=1;i<=n;i++)samadd(b[i]-'a');
	dfs(1,1);
	printf("%d\n",ans);
	return 0;
}

猜你喜欢

转载自blog.csdn.net/weixin_43960287/article/details/115268627