这篇博客讲的非常到位:博客地址
需要注意的几个点:
1.将trie树每个结点转化为一个状态,M[i][j]表示从状态i走一步到状态j的方法数
2.使用到了矩阵快速幂
3.若一点fail指向的节点有毒,该点也有毒,因为fail指向字符串是该点字符串的后缀,该部分可以使用深度优先搜索完成。
4.将带毒节点贡献置零
#include<iostream>
#include<string.h>
#include<queue>
using namespace std;
const int N = 105;
const int mod = 100000;
typedef long long ll;
int trie[N][4];
int fail[N];
int poison[N];
int index = 0;
struct Matrix{
ll m[N][N];
Matrix(int flag = 0){
memset(m,0,sizeof(m));
if(flag)
for(int i = 0;i <= index;i++)
m[i][i] = 1;
}
Matrix(const Matrix& mat){
for(int i = 0;i <= index;i++)
for(int j = 0;j <= index;j++)
m[i][j] = mat.m[i][j]%mod;
}
Matrix operator *(const Matrix & m1){
Matrix mat;
for(int i = 0;i <= index;i++){
for(int j = 0;j <= index;j++){
for(int k = 0;k <= index;k++){
mat.m[i][j] = ((m[i][k] * m1.m[k][j])%mod + mat.m[i][j])%mod;
}
}
}
return mat;
}
void show(){
for(int i = 0;i <= index;i++){
for(int j = 0;j <= index;j++)
cout<<m[i][j]<<" ";
cout<<endl;
}
}
};
Matrix fast_pow(Matrix mat,ll n){
Matrix ans(1);
while(n){
if(n&1)ans = ans * mat;
n>>= 1;
mat = mat * mat;
}
return ans;
}
int get(char c){
if(c == 'A')return 0;
if(c == 'C')return 1;
if(c == 'T')return 2;
if(c == 'G')return 3;
}
void insert(string s){
int p = 0,len = s.size();
for(int i = 0;i < len;i++){
int x = get(s[i]);
if(!trie[p][x])trie[p][x] = ++index;
p = trie[p][x];
}
poison[p] = 1;
}
void construct_fail(){
queue<int> q;
for(int i = 0;i < 4;i++){
if(trie[0][i]){
fail[trie[0][i]] == 0;
q.push(trie[0][i]);
}
}
while(!q.empty()){
int p = q.front();
q.pop();
for(int i = 0;i < 4;i++){
int x = trie[p][i];
if(x){
fail[x] = trie[fail[p]][i];
q.push(x);
}else
trie[p][i] = trie[fail[p]][i];
}
}
}
int dfs(int x){
if(x == 0)
return 0;
if(poison[x] == 1)
return 1;
return dfs(fail[x]);
}
void solve(ll n){
ll ans = 0;
for(int i = 0;i <= index;i++)
if(!poison[i])
poison[i] = dfs(fail[i]);
Matrix mat;
for(int i = 0;i <= index;i++){
for(int j = 0;j < 4;j++){
int x = trie[i][j];
if(!poison[x] && !poison[i])mat.m[i][x]++;
else mat.m[i][x] = 0;
}
}
mat = fast_pow(mat,n);
for(int i = 0;i <= index;i++)
ans = (ans + mat.m[0][i]) % mod;
cout<<ans<<endl;
}
ll n;
int m;
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>m>>n;
for(int i = 0;i < m;i++){
string s;
cin>>s;
insert(s);
}
construct_fail();
solve(n);
return 0;
}
忙活了一晚上ac了,我太难了!