AC自动机一种将 Trie 树和 KMP 相结合的算法。AC 自动机有以下三个过程:建立 Trie 树、建立失败指针、字符串匹配。
建立 Trie 树。这一步操作和 Trie 树的一样,将若干个模式串建立起一棵 Trie 树。
建立失败指针。这一步类似于 KMP 算法中建立 next 数组,同学们记得 KMP 中提到的失败指针 next 作用么?这是为了方便后续的匹配操作,当第 i 位失配时,则跳转到 next[i] 位进行匹配。在 AC 自动机中也是同样的功能,在某个结点 A 上失配时,则通过失败指针指向某结点 B 进行下一次的匹配(定义以结点 A 为终点的某个后缀为 s1;以根结点为起点,结点 B 为终点的串为 s2,则有 s1=s2,且保证在所有这样的串中, s1 和 s2 的长度是最大的),而不需要回溯到上一层,避免多余的匹配操作。
字符串匹配。从根结点开始,沿树的路径进行匹配,如果当前结点匹配成功则继续往下一个结点匹配,否则跳转到失败指针所指的结点进行匹配。重复上述过程,直到匹配完模式串为止。
我们可以用广度优先搜索的方法构造失败指针。每个结点都有一个失败指针,首先将根结点的失败指针指向空,根结点的直接子结点的失败指针指向根结点。对 Trie 树进行广度优先搜索,每个结点的失败指针都是由它父结点的失败指针决定的。例如,字符串she是从sh连向she的,sh的失败指针指向h,所以she的失败指针就指向”h” + “e” = “he”。若不存在he,则失败指针也就会指向更上方。
以下给出两份模板代码
第一个
const int MAX_N = 10000;
const int MAX_C = 26;
struct AC_Automaton {
int ch[MAX_N][MAX_C], fail[MAX_N], cnt[MAX_N]; // ch 和 cnt 数组与 Trie 树中的一样;fail 保存的是失败指针。ch 和 fail 默认都为 -1
int tot; // Trie 树的总结点(不含根结点)个数
void init() {
memset(ch, -1, sizeof(ch));
memset(fail, 0, sizeof(fail));
tot = 0;
memset(cnt, 0, sizeof(cnt));
}
void insert(char* str) {
int p = 0;
for (int i = 0; str[i]; ++i) {
if (ch[p][str[i] - 'a'] == -1) {
ch[p][str[i] - 'a'] = ++tot;
}
p = ch[p][str[i] - 'a'];
}
cnt[p]++;
}
void build() {
int l = 0, r = 0, Q[MAX_N];
for (int i = 0; i < MAX_C; i++) {
if (ch[0][i] == -1) {
ch[0][i] = 0;
} else {
Q[r++] = ch[0][i];
}
}
while (l < r) {
int p = Q[l++];
for (int i = 0; i < MAX_C; i++) {
if (ch[p][i] == -1) {
ch[p][i] = ch[fail[p]][i];
} else {
fail[ch[p][i]] = ch[fail[p]][i];
Q[r++] = ch[p][i];
}
}
}
}
int count(char* str) { // 统计一个字符串在给定的字符串集合(Trie)中出现了多少次
int ret = 0, p = 0;
for (int i = 0; str[i]; ++i) {
p = ch[p][str[i] - 'a'];
int tmp = p;
while (tmp) {
ret += cnt[tmp];
cnt[tmp] = 0; // 避免重复统计同一字符串
tmp = fail[tmp];
}
}
return ret;
}
};
第二个
#include <iostream>
#include <cstring>
using namespace std;
const int MAX_N = 10000;
const int MAX_C = 26;
struct AC_Automaton {
int ch[MAX_N][MAX_C], fail[MAX_N], cnt[MAX_N];
int tot;
void init() {
memset(ch, -1, sizeof(ch));
memset(fail, 0, sizeof(fail));
tot = 0;
memset(cnt, 0, sizeof(cnt));
}
void insert(char* str) {
int p = 0;
for (int i = 0; str[i]; ++i) {
if (ch[p][str[i] - 'a'] == -1) {
ch[p][str[i] - 'a'] = ++tot;
}
p = ch[p][str[i] - 'a'];
}
cnt[p]++;
}
void build() {
int l = 0, r = 0, Q[MAX_N];
for (int i = 0; i < MAX_C; i++) {
if (ch[0][i] == -1) {
ch[0][i] = 0;
} else {
Q[r++] = ch[0][i];
}
}
while (l < r) {
int p = Q[l++];
for (int i = 0; i < MAX_C; i++) {
if (ch[p][i] == -1) {
ch[p][i] = ch[fail[p]][i];
} else {
fail[ch[p][i]] = ch[fail[p]][i];
Q[r++] = ch[p][i];
}
}
}
}
int count(char* str) {
int ret = 0, p = 0;
for (int i = 0; str[i]; ++i) {
p = ch[p][str[i] - 'a'];
int tmp = p;
while (tmp) {
ret += cnt[tmp];
cnt[tmp] = 0;
tmp = fail[tmp];
}
}
return ret;
}
} ac;
int main() {
ac.init();
ac.insert("abcd");
ac.insert("bcd");
ac.insert("cd");
ac.insert("d");
ac.build();
cout << ac.count("abcd") << endl;
return 0;
}