在学习AC自动机前,请确保你已经充分理解
KMP算法 ANDTrie字典树
我们将从这样一个问题开始引入AC自动机
Q:给定n个模式串和1个文本串,求有多少个模式串在文本串中出现过
这个问题要怎么解?
用N次KMP吗,这样显然爆炸啊
于是闲着没事干脑袋又十分丰腴的科学家们有了一个奇妙的想法
在Trie上求KMP!(当然实际上只是类似KMP的nxt,定义还是有所不同的)
假设当前有5个模式串’she’, ‘he’, ‘say’, ‘shr’, ‘her’
先建出他们的字典树
建好字典树后我们效仿KMP的nxt数组
在Trie上增加fail失配指针
什么是fail指针
假设当前结点
所代表的串为
,那么
的
指针指向
最长的,能与
的后缀匹配的的Trie树的前缀 的结尾结点
(这都什么 #$*&@%¥^#)
是不是有点被绕晕了,那就看这图感性理解一下吧
比如最长的,能与串sh的后缀匹配的 Trie的前缀,只有串h
以及最长的,能与串she的后缀匹配的 Trie的前缀,只有串he
那么这个fail指针要怎么求呢
可以考虑用BFS实现
假设当前从队首取出结点
对于
的一个子节点
我们从
开始不断沿着
指针向上跳
直到跳到一个结点
也有表示字符
的子节点
那么
的
指针指向
特别的,如果一直跳到根都没有符合条件的结点
那么
的
指针指向根
以及注意所有第二层的结点
指针都指向根
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);//第二层节点fail都指向根
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
{
if(!ch[u][i]) continue;//没有这个子节点就跳过
int tt=fail[u];
while(!ch[tt][i]&&tt) tt=fail[tt];//沿着fail指针找到第一个也有同样子节点的结点
fail[ch[u][i]]=ch[tt][i];
q.push(ch[u][i]);
}
}
}
现在连好了fail指针,匹配就简单了
首先用一个指针指向根
将文本串一位一位送入自动机
若当前指针存在表示文本串下一位的子节点,令指针移向该子节点
否则沿着fail指针不断转移,直到跳到一个存在该子节点的结点,令指针移向该子节点
指针没跳转完成一次,就沿着fail指针统计一次
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
while(!ch[u][x]&&u) u=fail[u];
u=ch[u][x];
for(int t=u;t&&sum[t]!=-1;t=fail[t])
ans+=sum[t],sum[t]=-1;
}
}
AC自动机の应用
HDU - 2222 Keywords Search
上述问题的果题
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return x*f;
}
const int maxn=500010;
int Q,n,cnt;
char pat[maxn],txt[maxn<<1];
int ch[maxn][26],fail[maxn],sum[maxn];
queue<int> q;
int ans;
void ins(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
if(!ch[u][x]) ch[u][x]=++cnt;
u=ch[u][x];
}
sum[u]++;
}
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
{
if(!ch[u][i]) continue;
int tt=fail[u];
while(!ch[tt][i]&&tt) tt=fail[tt];
fail[ch[u][i]]=ch[tt][i];
q.push(ch[u][i]);
}
}
}
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
while(!ch[u][x]&&u) u=fail[u];
u=ch[u][x];
for(int t=u;t&&sum[t]!=-1;t=fail[t])
ans+=sum[t],sum[t]=-1;
}
}
void init()
{
ans=cnt=0;
memset(sum,0,sizeof(sum));
memset(ch,0,sizeof(ch));
}
int main()
{
Q=read();
while(Q--)
{
n=read(); init();
for(int i=1;i<=n;++i)
{
scanf("%s",&pat);
ins(pat,strlen(pat));
}
scanf("%s",&txt);
build_AC(); query(txt,strlen(txt));
printf("%d\n",ans);
}
return 0;
}
P3796 【模板】AC自动机(加强版)
Q:有N个由小写字母组成的模式串以及一个文本串T。每个模式串可能会在文本串中出现多次。你需要找出哪些模式串在文本串T中出现的次数最多。
也是稍作修改即可的果题
#include<iostream>
#include<cstdio>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
using namespace std;
int read()
{
int f=1,x=0;
char ss=getchar();
while(ss<'0'||ss>'9'){if(ss=='-')f=-1;ss=getchar();}
while(ss>='0'&&ss<='9'){x=x*10+ss-'0';ss=getchar();}
return x*f;
}
const int maxn=50010;
int n;
char pt[200][100],txt[maxn*20];
int ch[maxn][26],fail[maxn],cnt;
int id[maxn],num[200];
queue<int> q;
int ans;
void ins(char *ss,int len,int k)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
if(!ch[u][x]) ch[u][x]=++cnt;
u=ch[u][x];
}
id[u]=k;
}
void build_AC()
{
for(int i=0;i<=25;++i)
if(ch[0][i]) fail[ch[0][i]]=0,q.push(ch[0][i]);
while(!q.empty())
{
int u=q.front(); q.pop();
for(int i=0;i<=25;++i)
{
if(!ch[u][i]) continue;
int tt=fail[u];
while(!ch[tt][i]&&tt) tt=fail[tt];
fail[ch[u][i]]=ch[tt][i];
q.push(ch[u][i]);
}
}
}
void query(char *ss,int len)
{
int u=0;
for(int i=0;i<len;++i)
{
int x=ss[i]-'a';
while(!ch[u][x]&&u) u=fail[u];
u=ch[u][x];
for(int t=u;t;t=fail[t])
num[id[t]]++;
}
for(int i=1;i<=n;++i)
ans=max(ans,num[i]);
}
void init()
{
ans=cnt=0;
memset(ch,0,sizeof(ch));
memset(id,0,sizeof(id));
memset(num,0,sizeof(num));
}
int main()
{
while(scanf("%d",&n)!=EOF)
{
if(n==0) break; init();
for(int i=1;i<=n;++i)
{
scanf("%s",&pt[i]);
ins(pt[i],strlen(pt[i]),i);
}
scanf("%s",&txt);
build_AC(); query(txt,strlen(txt));
printf("%d\n",ans);
for(int i=1;i<=n;++i)
if(num[i]==ans) printf("%s\n",pt[i]);
}
return 0;
}