Codeforces 1140E DP

题意:给你一个数组,如果数组中的某个位置是-1那就可以填1到m的数字中的一个,但是要遵守一个规则:不能出现长度为奇数回文的子串,问合法的填法有多少种?

思路:不出现长度为奇数的回文子串,只需不出现长度为3的回文子串就可以了,那么i位置和i - 2位置填的数字不能一样。所以,我们可以把这个数组拆成2部分,所有的奇数位置和所有的偶数位置分别成一个串,之后吧两个串的答案乘起来就是答案了。每个串肯定是下列情况中的一种或多种构成。

1:全是-1, 那么明显有k * (k - 1) ^ (n - 1)种答案。

2:有一遍是大于0的数,假设长度为n,那么答案就是(k - 1) ^ n。

3:两边都是大于0的数,这个需要DP预处理后得到答案。

设dp[i][0 / 1]表示长度为i的-1串,当前最后一个-1填的数字与最前面的数字不同/相同,合法的方案数。

转移是这样:dp[i][1] = dp[i - 1][0], 因为i - 1位置填的数字与最前面的不同,所有填一个与最前面位置相同的数字是合法的,所有直接转移。

dp[i][0] = (k - 1) * dp[i - 1][1] + (k - 2) * dp[i - 1][0],前半部分很好理解,加上一个和最前面不等的数就可以了。后半部分,每个数不能和最前面相等,也不能和自己相等,所有是k - 2个转移。

代码:

#include <bits/stdc++.h>
#define LL long long
using namespace std;
const LL mod = 998244353ll;
const int maxn = 200010;
LL dp[maxn][2];
LL a[maxn], b[maxn];
int tot = 0; 
LL m;
int n;
LL qpow(LL x, LL y) {
	LL ans = 1;
	for (; y; y >>= 1) {
		if(y & 1) ans = (ans * x) % mod;
		x = (x * x) % mod;
	}
	return ans;
}
LL solve() {
	int pos = 1, cnt = 0;
	LL ans = 1;
	while(pos <= tot) {
		while(a[pos] > 0 && pos <= tot) {
			if(pos > 1 && a[pos] == a[pos - 1]) return 0;
			pos++;
		} 
		int pre = pos - 1;
		while(a[pos] < 0 && pos <= tot) {
			pos++;
			cnt++;
		}
		if(pre == 0 && pos == tot + 1) return (m * qpow(m - 1, tot - 1)) % mod;
		else if(pre == 0 || pos == tot + 1) ans = (ans * qpow(m - 1, pos - pre - 1)) % mod;
		else {
			int flag = (int)(a[pre] == a[pos]);
			if(flag == 1)
				ans = (ans * dp[pos - pre - 1][0]) % mod;
			else {
				LL tmp = (dp[pos - pre - 1][1] + (((dp[pos - pre - 1][0] * qpow(m - 1, mod - 2)) % mod) * (m - 2)) % mod) % mod;
				ans = (ans * tmp) % mod; 
			} 	
		} 
	}
	return ans;	
}
int main() {
	scanf("%d%lld", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%lld", &b[i]);
	}
	dp[1][0] = m - 1;
	dp[1][1] = 0;
	for (int i = 2; i <= n; i++) {
		dp[i][0] = ((dp[i - 1][1] * (m - 1)) % mod + (dp[i - 1][0] * (m - 2)) % mod) % mod;
		dp[i][1] = dp[i - 1][0];
	}
	LL ans1, ans2;
	for (int i = 1; i <= n; i += 2)
		a[++tot] = b[i];
	ans1 = solve();
	tot = 0;
	for (int i = 2; i <= n; i += 2)
		a[++tot] = b[i];
	ans2 = solve();
	printf("%lld\n", (ans1 * ans2) % mod);	
}
 

  

猜你喜欢

转载自www.cnblogs.com/pkgunboat/p/10836649.html