概况
AC自动机,全名Aho-Corasick自动机。是一种字符串匹配算法。在学AC自动机之前,必须知道字典树Trie和字符串匹配算法KMP。KMP算法我写过一篇博客,可以点击这里学习,而对于Trie算法,大家可以去网上搜索。接下去讲AC自动机的时候就默认会Trie和KMP了。
AC自动机原理
我们都知道KMP算法可以高效地将模式串与文本串匹配。但是如果我们有许多模式串需要进行匹配时,就会遇到麻烦。因为用KMP每次查找一个模板,都得遍历整个文本串。可不可以只遍历一次呢?答案是可行的。方法就是将所有模式串建成一个大的状态转移图,这个图就叫AC自动机。由于KMP的转移是由线性的字符和失配next数组组成的,所以我们不难想象AC自动机也要加上失配边。
整个AC自动机的框架是由Trie构成的,我们将所有模式串构成一棵字典树,如果我们能找到类似KMP里的next数组的对应关系,那么AC自动机也就讲完了,那么我们应该怎么构筑呢?
计算失配数组fail
我们举个例子,假设我们已经用单词ABAB,BA,BB构成了一棵Trie,如图:
。
通过这样的构筑,我们可以保证x失配后往fail[x]跳后前缀是一样的,所以可以继续往下匹配。
匹配
构筑完Trie和fail数组后,就可以开始匹配了。在匹配时我们只要遵循:
1.如果失配,则当前位置p变为fail[p],继续匹配,直至匹配到根。
2.如果匹配成功,若有以当前点结尾的模式串,则统计答案,并跳到fail[p]继续查看是否有模式串;如果没有,直接匹配文本串里的下一个字符。
比如我们要匹配文本串BBABAA。
首先B往节点6匹配,第二个字符也是B,所以匹配到节点8,节点8存在一起为结尾的模式串,所以统计答案,并一路fail到根(此时节点p依旧在8)。然后是字符A,失配了,于是跳到8的fail6,继续匹配,匹配到7,刚好匹配,继续统计答案。有一路匹配到根。接下去B字符,优失配,所以跳到A的fail点2上,与2匹配,接下去B也匹配,一直匹配到点5。然而5为B,但文本串最后一位是A,失配了,所以一直往上fail到根。结束匹配。
模板
因此我刚开始打了个递归的模板。
#include<bits/stdc++.h>
using namespace std;
int n,cnt=1,p=1,temp,ans;
string s,k;
struct node{
int sum,vis,fail,next[26];
}F[500005];
void build(string s){
int pl=1;
for(int i=0;i<s.length();i++){
char c=s[i];
if(F[pl].next[c-'a']) pl=F[pl].next[c-'a'];
else{
cnt++;F[pl].next[c-'a']=cnt;
pl=cnt;
}
}
F[pl].sum++;
}
int find(int num,int pl){
if(F[pl].next[num]) return F[pl].next[num];
else return pl==1?1:find(num,F[pl].fail);
}
void fail(int x){
for(int i=0;i<26;i++){
if(!F[x].next[i]) continue;
if(x==1) F[F[x].next[i]].fail=1;
else F[F[x].next[i]].fail=find(i,F[x].fail);
fail(F[x].next[i]);
}
}
void match(int num){
if(F[p].next[num]){
p=F[p].next[num];
if(F[p].sum&&!F[p].vis){
F[p].vis=1;ans+=F[p].sum;temp=F[p].fail;
while(temp!=1){
if(F[temp].sum&&!F[temp].vis){
F[temp].vis=1;ans+=F[temp].sum;
}
temp=F[temp].fail;
}
}
return;
}
else{
p=F[p].fail;if(p!=1) match(num);
}
}
int main()
{
scanf("%d",&n);F[1].fail=1;
for(int i=1;i<=n;i++){
cin>>s;build(s);
}
fail(1);
cin>>k;
for(int i=0;i<k.length();i++){
temp=0;
match(k[i]-'a');
}
printf("%d",ans);
return 0;
}
然而实际上还是可以进行优化,我们在构筑fail数组时,有时得跳到它父亲的fail的fail的fa行优化,我们在构筑fa
il数组时,有时得跳到它父亲的fail的fail的fail……,所以我们可以直接将最后出现的点记录在字典树上,即:
F[x].next[i]=F[F[x].fail].next[i]
所以有了如下代码(非递归版,内存消耗较小)
#include<bits/stdc++.h>
using namespace std;
int n,cnt=1,p=1,temp,ans,q[1000005];
string s,k;
struct node{
int sum,vis,fail,next[26];
}F[1000005];
int read(){
char c;int x;while(c=getchar(),c<'0'||c>'9');x=c-'0';
while(c=getchar(),c>='0'&&c<='9') x=x*10+c-'0';return x;
}
void build(string s){
int pl=1;
for(int i=0;i<s.length();i++){
char c=s[i];
if(F[pl].next[c-'a']) pl=F[pl].next[c-'a'];
else{
cnt++;F[pl].next[c-'a']=cnt;
pl=cnt;
}
}
F[pl].sum++;
}
void fail(){
int h=0,t=0;
for(int i=0;i<26;i++) if(F[1].next[i]) F[q[++t]=F[1].next[i]].fail=1;
while(h<t){
int pl=q[++h];
for(int i=0;i<26;i++)
if(F[now].next[i]){
F[q[++t]=F[now].next[i]].fail=F[F[now].fail].next[i];
if(!F[F[now].next[i]].fail)F[F[now].next[i]].fail=1;
}
else F[now].next[i]=F[F[now].fail].next[i];
}
}
void match(){
for(int i=0;i<k.length();i++){
if(!p) p=1;
p=F[p].next[k[i]-'a'];
for(temp=p;temp>1&&F[temp].sum!=-1;temp=F[temp].fail) ans+=F[temp].sum,F[temp].sum=-1;
}
printf("%d",ans);
}
int main()
{
n=read();F[1].fail=1;
for(int i=1;i<=n;i++){
cin>>s;build(s);
}
fail();
cin>>k;
match();
return 0;
}
一道AC自动机的例题
这道题是一道AC自动机的模板题,因为我么要统计出现了几次,所以每次匹配过之后不将其赋值为-1。而是新开一个tim,变量,每访问一次,就tim++。最后统计最大的tim即可。
#include<bits/stdc++.h>
using namespace std;
int n,p=1,temp,cnt=1,q[50000],maxnum;
struct node{
int num,sum,fail,tim,fa,next[26];
}F[50000];
string s,k;
void build(string s){
int pl=1;
for(int i=0;i<s.length();i++){
if(F[pl].next[s[i]-'a']) pl=F[pl].next[s[i]-'a'];
else{
cnt++;F[cnt].num=s[i]-'a';F[cnt].fa=pl;
F[pl].next[s[i]-'a']=cnt;pl=cnt;
}
}
F[pl].sum++;
}
void fail(){
int h=0,t=0;
for(int i=0;i<26;i++) if(F[1].next[i]) F[q[++t]=F[1].next[i]].fail=1;
while(h<t){
int now=q[++h];
for(int i=0;i<26;i++)
if(F[now].next[i]){
F[q[++t]=F[now].next[i]].fail=F[F[now].fail].next[i];
if(!F[F[now].next[i]].fail)F[F[now].next[i]].fail=1;
}
else F[now].next[i]=F[F[now].fail].next[i];
}
}
void match(string k){
for(int i=0;i<k.length();i++){
if(!p) p=1;p=F[p].next[k[i]-'a'];
for(temp=p;temp;temp=F[temp].fail)
if(F[temp].sum>0) F[temp].tim++,maxnum=max(maxnum,F[temp].tim);
}
}
void print(int pl,int times){
int sta[105],top=0;
for(int i=pl;i>1;i=F[i].fa) sta[++top]=F[i].num;
string q;for(int i=top;i>=1;i--) q.push_back(sta[i]+'a');
while(times){
times--;cout<<q<<endl;
}
}
int main()
{
while(~scanf("%d",&n)){
if(!n) break;
memset(F,0,sizeof(F));
maxnum=0;p=1;cnt=1;
for(int i=1;i<=n;i++){
cin>>s;build(s);
}
fail();
cin>>k;
match(k);
printf("%d\n",maxnum);
for(int i=1;i<=cnt;i++)
if(F[i].tim==maxnum) print(i,F[i].sum);
}
return 0;
}