Codeforces 1178F DP

题意:有一张白纸条,你需要给这张纸条染色。染色从颜色1开始染色,每次选择纸条的一段染色时,这一段的颜色必须是相同的。现在给你染色后的纸条,问有多少种染色方案?

F1: 思路:最开始的想法是以染色顺序为一个维度,然后染色区间为另外两个维度去DP,但是最后发现不可以,因为之前的所有的染色对后面的影响不确定,只用在染第几种颜色是无法确定现在可以染色的区间的,还无法记忆化,只能去看题解。。。官方题解的DP设计的比较巧妙。我们DP并不关心染色顺序,我们的关注点在区间。我们可以发现,对于一个区间,第一次开始染的颜色一定是标号最小的那种颜色。假设当前区间是[l, r],标号最小的颜色c的位置是p,并且在这个区间内染上颜色c的区间是[a, b],那么区间被分为了4部分:[l, a - 1], [a, p - 1], [p + 1, b], [b + 1, r],递归的求解这四部分,再把答案乘起来就可以了。容易发现a和b的枚举是独立的,那么把枚举的a的答案和枚举的b的答案分别求出,乘起来就可以了。

代码:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define pii pair<int, int>
#define db double
#define LL long long
using namespace std;
const int maxn = 510;
const LL mod = 998244353;
LL dp[maxn][maxn];
int n, m;
pii c[maxn];
set<int> s, s1;
set<int> ::iterator it;
int L[maxn], R[maxn];
int cnt[maxn];
int mi[maxn][maxn];
LL dfs(int l, int r) {
	if(r <= l) return 1;
	if(dp[l][r] != -1) return dp[l][r];
	int p =  mi[l][r];
	LL sum1 = 0, sum2 = 0;
	for (int i = l - 1; i < p; i++) {
		sum1 += (dfs(l, i) * dfs(i + 1, p - 1)) % mod;
	}
	sum1 %= mod;
	for (int j = p; j <= r; j++) {
		sum2 += (dfs(p + 1, j) * dfs(j + 1, r)) % mod;
	}
	sum2 %= mod;
	dp[l][r] = (sum1 * sum2) % mod;
	return dp[l][r];
}
int main() {
	scanf("%d%d", &n, &m);
	memset(dp, -1, sizeof(dp));
	for (int i = 1; i <= n; i++) {
		scanf("%d", &c[i].first);
		c[i].second = i;
		L[i] = 1, R[i] = n;
	}
	s.insert(0);
	s.insert(n + 1);
	sort(c + 1, c + 1 + n);
	for (int i = 1; i <= n; i++) {
		L[i] = (*(--s.lower_bound(c[i].second)));
		R[i] = (*s.lower_bound(c[i].second));
		L[i]++, R[i]--;
		for (int j = L[i]; j <= c[i].second; j++)
			for (int k = c[i].second; k <= R[i]; k++)
				mi[j][k] = c[i].second;
		s.insert(c[i].second);
	}
	cout << dfs(1, n) << endl;
}

F2: 现在纸条的范围是1e6,不能直接区间DP了。不过通过观察发现,如果纸条上有多个相邻的位置颜色相同,那么可以缩成一个位置,因为这些位置整体要么都不覆盖,要么全部覆盖,和一个位置的贡献相同。容易发现如果方案合法缩点之后点数不会超过1000,所有超过1000的可以直接输出0。缩点之后和F1已经很像了,不过有些小细节需要注意。首先需要判断方案是否合法,容易发现如果有两个颜色相同的位置中间有标号比它们小的颜色,方案不合法。这个直接暴力判断就可以了。其次,在DP计算方案时,除了F1中的四个部分,区间内有可能有多最小标号的位置,这些位置之间的区间也要乘起来。最后,我们不能直接像F1那样枚举区间,比如这种例子:

8 10
8 4 3 5 7 5 2 6 2 1

容易发现,我们在枚举2的覆盖区间时,覆盖区间的左端点不能在2个5之间(包括最右边的5),如果2在这里覆盖了,5这种情况是不可能出现的。这个用指针之类的处理一下就可以了。

代码:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define pii pair<int, int>
#define db double
#define LL long long
using namespace std;
const int maxn = 1010;
const LL mod = 998244353;
LL dp[maxn][maxn];
int n, m;
int c[1000010];
int L[maxn], R[maxn];
int cnt[maxn], last[maxn], Next[maxn];
int mi[maxn][maxn];
vector<pii> re[maxn];
LL dfs(int l, int r) {
	if(r <= l) return 1;
	if(dp[l][r] != -1) return dp[l][r];
	int lp =  L[mi[l][r]], rp = R[mi[l][r]];
	LL sum1 = 0, sum2 = 0, sum3 = 1;
	int now = mi[l][r];
	for (int i = 0; i < re[now].size(); i++) {
		sum3 = (sum3 * dfs(re[now][i].first, re[now][i].second)) % mod;
	}
	for (int i = l - 1; i < lp; i = Next[i]) {
		sum1 += (dfs(l, i) * dfs(i + 1, lp - 1)) % mod;
	}
	sum1 %= mod;
	for (int j = rp; j <= r; j = Next[j]) {
		sum2 += (dfs(rp + 1, j) * dfs(j + 1, r)) % mod;
	}
	sum2 %= mod;
	dp[l][r] = (((sum1 * sum2) % mod) * sum3) % mod;
	return dp[l][r];
}
int main() {
	scanf("%d%d", &n, &m);
	for (int i = 1; i <= m; i++) {
		scanf("%d", &c[i]);
	}
	int tot = 0;
	tot = 1;
	for (int i = 2; i <= m; i++) {
		if(c[i] == c[i - 1]) {
			continue;
		} else {
			c[++tot] = c[i];
		}
	}
	if(tot > 1000) {
		printf("0\n");
		return 0;
	}
	memset(dp, -1, sizeof(dp));
	memset(last, -1, sizeof(last));
	for (int i = 1; i <= tot; i++) {
		mi[i][i] = c[i];
		for (int j = i + 1; j <= tot; j++) {
			mi[i][j] = min(c[j], mi[i][j - 1]);
		}
		L[i] = tot + 1;
		R[i] = 0;
	}
	bool flag = 0;
	for (int i = 1; i <= tot; i++) {
		L[c[i]] = min(L[c[i]], i);
		R[c[i]] = max(R[c[i]], i);
		if(last[c[i]] != -1) {
			if(mi[last[c[i]] + 1][i - 1] < c[i]) {
				flag = 1;
				break;
			}
			re[c[i]].push_back(make_pair(last[c[i]] + 1, i - 1));
		}
		last[c[i]] = i;
	}
	for (int i = 0; i <= tot; i++) {
		Next[i] = R[c[i + 1]];
	}
	Next[tot] = tot + 1;
	if(flag) {
		printf("0\n");
		return 0;
	}
	cout << dfs(1, tot) << endl;
	return 0;
}

  

猜你喜欢

转载自www.cnblogs.com/pkgunboat/p/11222356.html