POJ - 2778 DNA Sequence (AC自动机+矩阵快速幂+图论)

题意:给你N个模式串( 只含 'A' , 'T' , 'G' , 'C' )以及一个文本串长度m,每个串的长度不会超过10,求出有多少个串满足不包含所有的模式串,并且长度为m. 结果 mod 100000

思路:我们要使组成的串不包含模式串,如果从AC自动机的角度来想就是每一次都不能走到某个模式串结尾的位置(结点)同时他的下一个位置也不能是某个模式串的结尾,不然就不符合条件。

然后我们把符合条件的结点。以及能够互相达到的结点用邻接矩阵来存。在根据图论(下面有将)中有关邻接矩阵幂的性质来求得以 a 开始 b 结尾(长为m)的方案个数。最后我们对每个结果求和即可。其中有关矩阵乘法,我们可以用矩阵快速幂来求得。

(图论)邻接矩阵的幂:在图论中,我们可以用0,1的邻接矩阵表示图中边集,或者说点集中两两间的连通性。

设A是某个图的邻接矩阵
考虑矩阵乘法的定义:
\[令\space C=A \times B,\space\space即C_{ij}=\sum\limits_{k=1}^nA_{ik}\times B_{kj}\]
\[则A^2_{ij}=\sum\limits_{k=1}^nA_{ik}\times A_{kj}\]
邻接矩阵\(A\)中的元素都是用\(0,1\)来表示是否联通的,或者说,代表有没有方法从i走到j。那么,\(A_{(i,j)}×A_{(k,j)}\)就是表示从\(i\) 走到 \(k\) 再走到 \(j\) 是否可行。可以发现,\(A^2\)就是取了一个 \(\sum\) ,即统计用2步从i走到\(j\)的方法总数。
考虑累乘的效果,矩阵\(A_m\)所代表的意义就是从点与点之间走\(m\)步能够到达的方案总数

code:

#include <algorithm>
#include <iostream>
#include <cstring>
#include <vector>
#include <cstdio>
#include <string>
#include <cmath>
#include <queue>
#include <set>
#include <map>
#include <complex>
#include<stack>
#define ll long long
using namespace std;
const int N = 250;
const int maxn = 1e3;
const int mod = 100000;
queue<int>q;
int mapp[maxn][maxn];
int size;

struct AC_Automata{
    int tire[N][4];
    int val[N];
    int last[N];
    int fail[N];
    int count[N];
    int tot;

    void init(){
        tot = 1;
        val[0] = fail[0] = last[0]  =0;
        memset(tire[0],0,sizeof(tire[0]));
    }

    int get_id(char ch){
        if(ch=='A') return 0;
        if(ch=='C') return 1;
        if(ch=='G') return 2;
        if(ch=='T') return 3;

    }
    void insert(char *s,int v){
        int len = strlen(s);
        int root = 0;
        for(int i=0;i<len;i++){
            int id = get_id(s[i]);
            if(tire[root][id]==0){
                tire[root][id] = tot;
                memset(tire[tot],0,sizeof(tire[tot]));
                val[tot++] = 0;
            }
            root = tire[root][id];  
        }
        val[root] = 1;  
    }

    void build(){
        while(!q.empty()) q.pop();
        last[0] = fail[0] = 0;
        for(int i=0;i<4;i++){
            int root = tire[0][i];
            if(root!=0){
                fail[root] = 0;
                last[root] = 0;
                q.push(root);
            }
        }
        while(!q.empty()){
            int k = q.front();
            q.pop();
            for(int i=0;i<4;i++){
                int u = tire[k][i];
                if(u==0){
                    tire[k][i] = tire[fail[k]][i];
                    continue;
                }
                q.push(u);
                fail[u] =tire[fail[k]][i];
                last[u] = val[fail[u]]?fail[u]:last[fail[u]];
            }
        }
    } 

    void get_count(int i){
        if(val[i]){
            count[val[i]]++;
            get_count(last[i]);
        }
    }

    void query(char *s){
        int len = strlen(s);
        int j=0;
        for(int i=0;i<len;i++){
            int id = get_id(s[i]);
            while(j && tire[j][id]==0){
                j=fail[j];
            }
            j = tire[j][id];
            if(val[j])
                get_count(j);
            else if(last[j])
                get_count(j);
        }
    }
}ac;


#define Mod 100000
#define N 200
ll a[N][N];
int n;

void Multi(ll a[][N], ll b[][N], ll c[][N]) {
    for (int i=0; i<n; i++)
        for (int j=0; j<n; j++) {
            c[i][j] = 0;
            for (int k=0; k<n; k++)
                c[i][j] = (c[i][j] + a[i][k]*b[k][j]) % Mod;
        }
}

void copy(ll d[][N], ll s[][N]) {
    for (int i=0; i<n; i++) for (int j=0; j<n; j++)
        d[i][j] = s[i][j];
}

void PowerMod(ll a[][N], ll b) {
    ll t[N][N], ret[N][N];
    for (int i=0; i<n; i++) ret[i][i] = 1;
    while (b) {
        if (b & 1) { Multi(ret, a, t); copy(ret, t); }
        Multi(a, a, t); copy(a, t);
        b >>= 1;
    }
    copy(a, ret);
}
void init() {
    n = ac.tot;
    int u;
    memset(a, 0, sizeof(a));
    for (int i=0; i<n; i++) if (!ac.val[i] && !ac.last[i]) {
        for (int j=0; j<4; j++) {
            u = ac.tire[i][j];
            if (!ac.val[u] && !ac.last[u]) a[i][u]++;
        }
    }
}
int main() {
    char s[12];
    int m;
    ll b;
    while (scanf("%d %lld", &m, &b) == 2) {
        ac.init();
        for (int i=1; i<=m; i++) {
            scanf(" %s", s); ac.insert(s, i);
        }
        ac.build();
        init();
        PowerMod(a, b);
 
        ll sum = 0;
        for (int i=0; i<n; i++) sum = (sum + a[0][i]) % Mod;
        printf("%lld\n",sum);
    }
 
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Tianwell/p/11396233.html