字符串-KMP
作用:在一个文本字符串中找模式字符串出现次数、位置。
前缀知识:
字符串。
算法名字来源:发明人
Knuth(D.E.Knuth)&Morris(J.H.Morris)&Pratt(V.R.Pratt)。
讲解:
比如要在文本字符串
a=ababaababaabab 中找模式字符串
b=abaabab,暴力的做法就是枚举
a[i]==b[1],然后对
a[i∼i+len(b)−1] 和
b[1∼len(b)] 进行匹配,代码:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n,m,ans;
char a[N],b[N];
int main(){
scanf("%s%s",a+1,b+1);
n=strlen(a+1),m=strlen(b+1);
for(int i=1;i<=n-m+1;i++)
if(a[i]==b[1]){
bool ok=1;
for(int j=2;j<=m;j++)
if(a[i+j-1]!=b[j]){ok=0;break;}
if(ok) ans++;
}
printf("%d\n",ans);
return 0;
}
时间复杂度为
Θ(n×m),爆率百分百。而
Θ(n+m) 的KMP的精华就在于,每次上面代码标记的那行失配(匹配失败,
a[i+j−1]!=b[j])以后,不需要让模式串
b 从头开始匹配,而是跳到一个固定的位置,开始匹配。
如下,灰色表示待匹配,绿色表示正在匹配(成功),红色表示正在匹配(失败),黑色表示已经匹配:
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
文本串和模式串失配,不需要如下让模式串
b 从头开始匹配:
ababaababaabab
abaabab ←错误示范
而是应该这样:
ababaababaabab
abaabab
这时我能感受到你诧异的表情,这不是玄学穿越,而是有依据的。对于模式串
b 成功匹配的前三个字符
aba,满足该字符串最多前
1 个字符等于后
1 个字符,而前
2 个字符就不等于后
2 个字符了。所以这时,就可以知道两点:
1.
b 的前
1 个字符能和
a 的第
3∼3 个字符匹配。
2.如果把
b 的第
1 个字符对
a 的第
2∼2 个字符,必将不会整个匹配成功。
所以根据
b 前
3 个字符组成的子串中最多前几个字符等于后几个字符,就可以得出失配后跳转的方法。为了更全面具体的解说,看如下继续匹配:
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
如上,成功发现了一个模式串
b 在文本串
a 中出现的位置。这时候就不能在再沿着
b 继续匹配下去了,所以也可以看作是失配。因为对于字符串
b 的成功匹配的前
7 个字符组成的字符串,满足前两个字符等于后两个字符等于
ab,所以这么跳转匹配:
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
ababaababaabab
abaabab
然后又发现一个模式串
b 在文本串
a 中出现的位置,并且所有
a 的所有字符都已经匹配结束,所以结束匹配。最终得出,
b 在
a 中出现了
2 次,两次中
b 的第一个字符分别对应
a 的第
3 个和第
8 个字符。
所以如果我们现在已经有数组
nex[x] 表示
b 的前
x 个字符所组成的字符串中,最多前
nex[x] 个字符与后
nex[x] 个字符完全一样(
0≤nex[x]<x),那么匹配的代码就可以这么写:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
class charstar{
public:char arr[N];
int len;
char& operator[](int x){return arr[x];}
void leng(){len=strlen(arr+1);}
}a;
class KMP:public charstar{
public:
int nex[N];
void build(){
}
void found(charstar&book,queue<int>&q){
for(int i=1,j=0;i<=book.len;i++){
while(j&&book[i]!=arr[j+1]) j=nex[j];
if(book[i]==arr[j+1]) j++;
if(j==len) q.push(i-len+1),j=nex[j];
}
}
}b;
queue<int> ans;
int main(){
scanf("%s%s",&a[1],&b[1]);
a.leng(),b.leng();
b.build(),b.found(a,ans);
while(ans.size()) printf("%d\n",ans.front()),ans.pop();
for(int i=1;i<=b.len;i++) printf("%d%c",b.nex[i],"\n "[i<b.len]);
return 0;
}
这样的算法时间复杂度是
Θ(n+m) 的,为了保证复杂度,求
nex[] 数组也必须
Θ(m+m)。聪明的三个科学家想到了一个很微妙的方法——
b 自己匹配自己。这就难解释了,放代码:
void build(){
for(int i=2,j=0;i<=len;i++){
while(j&&arr[j+1]!=arr[i]) j=nex[j];
if(arr[j+1]==arr[i]) j++;
nex[i]=j;
}
}
和上面的匹配几乎一模一样。
如果你懂了,蒟蒻就放代码了:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
class charstar{
public:char arr[N];
int len;
char& operator[](int x){return arr[x];}
void leng(){len=strlen(arr+1);}
}a;
class KMP:public charstar{
public:
int nex[N];
void build(){
for(int i=2,j=0;i<=len;i++){
while(j&&arr[j+1]!=arr[i]) j=nex[j];
if(arr[j+1]==arr[i]) j++;
nex[i]=j;
}
}
void found(charstar&book,queue<int>&q){
for(int i=1,j=0;i<=book.len;i++){
while(j&&book[i]!=arr[j+1]) j=nex[j];
if(book[i]==arr[j+1]) j++;
if(j==len) q.push(i-len+1),j=nex[j];
}
}
}b;
queue<int> ans;
int main(){
scanf("%s%s",&a[1],&b[1]);
a.leng(),b.leng();
b.build(),b.found(a,ans);
while(ans.size()) printf("%d\n",ans.front()),ans.pop();
for(int i=1;i<=b.len;i++) printf("%d%c",b.nex[i],"\n "[i<b.len]);
return 0;
}
如果你看不惯这种匹配双重循环的版本,另一个版本:
#include <bits/stdc++.h>
using namespace std;
const int N=1e6+10;
class charstar{
public:char arr[N];
int len;
char& operator[](int x){return arr[x];}
void leng(){len=strlen(arr+1);}
}s1;
class KMP:public charstar{
public:
int nex[N];
void build(){
for(int i=1,j=0;i<=len;)
if(!j||arr[i]==arr[j]) nex[++i]=++j;
else j=nex[j];
}
void found(charstar&book,queue<int>&q){
for(int i=1,j=1;i<=book.len;){
if(!j||book[i]==arr[j]) i++,j++;
else j=nex[j];
if(j==len+1) q.push(i-len),j=nex[j];
}
}
}s2;
queue<int> ans;
int main(){
scanf("%s%s",&s1[1],&s2[1]);
s1.leng(),s2.leng();
s2.build(),s2.found(s1,ans);
while(ans.size()) printf("%d\n",ans.front()),ans.pop();
for(int i=1;i<=s2.len;i++) printf("%d%c",s2.nex[i+1]-1,"\n "[i<s2.len]);
return 0;
}
字符串学习之路(
★ 表示当前学习知识):
hash-kmp★-manacher-exkmp-trie-acam-sa-sam-pam
祝大家学习愉快!