poj-2778 DNA Sequence[AC自动机+矩阵快速幂]

题目地址
因为n很大,很自然想到用矩阵快速幂。
首先要知道对于一个01矩阵,如果m[i, j] = 1,表示从i到j有一条路,那么这个矩阵在自乘n次后,m[i,j]表示从i->j走n步的的方案数(离散书上的图论相关内容)
所以只需要找到哪些边是可以走的,然后跑一下矩阵快速幂就行了。
先把病毒放到ac自动机里面,对于结尾的节点标记一下,然后把可以走的边放到矩阵里面跑一下矩阵快速幂就行了。

#include<queue>
#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
const int mod = 100000;
typedef long long ll;
int Hash[125];
struct AC{
	int nex[300][4], tot, root;
	int f[300], ed[300];
	int newnode() {
		for(int i = 0; i < 4; i++) nex[tot][i] = -1;
		ed[tot] = 0;
		return tot++;
	}
	void init() {
		tot = 0;
		root = newnode();
	}
	void insert(char *s) {
		int u = root, len = strlen(s);
		for(int i = 0; i < len; i++) {
			int ch = Hash[s[i]];
			if(nex[u][ch] == -1) nex[u][ch] = newnode();
			u = nex[u][ch];
		}
		ed[u] = 1;
	}
	void getfail() {
		queue<int>Q;
		for(int i = 0; i < 4; i++) {
			if(nex[root][i] == -1) nex[root][i] = root;
			else {
				f[nex[root][i]] = root;
				Q.push(nex[root][i]);
			}
		}
		while(!Q.empty()) {
			int u = Q.front();Q.pop();
			ed[u] |= ed[f[u]];
			for(int i = 0; i < 4; i++) {
				if(nex[u][i] == -1) nex[u][i] = nex[f[u]][i];
				else {
					f[nex[u][i]] = nex[f[u]][i];
					Q.push(nex[u][i]);
				}
			}
		}
	}
}ac;
class matrix{
public:
	ll a[105][105];
	int n, m;
	matrix(int n, int m) {
		this->n = n;
		this->m = m;
		memset(a, 0, sizeof(a));
	}
	matrix operator *(matrix &b) {
		matrix c(n, b.m);
		for(int i = 1; i <= n; i++) {
			for(int j = 1; j <= b.m; j++) {
				for(int k = 1; k <= m; k++) {
					c.a[i][j] += a[i][k] * b.a[k][j];
                	if (c.a[i][j] > mod) {
                 	   c.a[i][j] %= mod;
                	}
				}
			}
		}
		return c;
	}
	matrix pow(int x) {
		matrix res(n, n), A(n, n);
		for(int i = 1; i <= n; i++) res.a[i][i] = 1;
		for(int i = 1; i <= n; i++) {
			for(int j = 1; j <= n; j++) {
				A.a[i][j] = a[i][j];
			}
		}
		while(x) {
			if(x&1) res = res*A;
			A = A*A;
			x >>= 1;
		}
		return res;
	}
};
char ss[10];
int m, n;
int main() {
	Hash['A'] = 0;Hash['C'] = 1;Hash['T'] = 2;Hash['G'] = 3;
	while(scanf("%d%d", &m, &n) == 2) {
		ac.init();
		for(int i = 1; i <= m; i++) {
			scanf("%s", ss);
			ac.insert(ss);
		}
		ac.getfail();
		matrix A(ac.tot, ac.tot);
		for(int i = 0; i < ac.tot; i++) {
			if(ac.ed[i]) continue;
			for(int j = 0; j < 4; j++) {
				int t = ac.nex[i][j];
				if(!ac.ed[t]) A.a[i+1][t+1]++;
			}
		}
		matrix ans(ac.tot, ac.tot);
	  	ans = A.pow(n);
	    ll sum = 0;
	    for (int i = 1; i <= ac.tot; ++i) {
	        sum = sum + ans.a[1][i];
	        if (sum > mod) {
	            sum %= mod;
	        }
	    }
	    printf("%lld\n", sum);
	}
}

猜你喜欢

转载自blog.csdn.net/qq_39921637/article/details/89074586