Codeforces 750E 线段树DP

题意:给你一个字符串,有两种操作:1:把某个位置的字符改变。2:询问l到r的子串最少需要删除多少个字符,使得这个子串含有2017子序列,并且没有2016子序列?

思路:线段树上DP,我们设状态0, 1, 2, 3, 4分别为: null, 2, 20, 201, 2017的最小花费,我们用线段树来维互状态转移的花费矩阵,合并相邻的两个子串的时候直接转移即可。

代码:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
#define ls (o << 1)
#define rs (o << 1 | 1)
using namespace std;
const int maxn = 200010;
int a[maxn];
char s[maxn];
struct node {
	int f[5][5];
	void init(int x) {
		for (int i = 0; i < 5; i++) {
			for (int j = 0; j < 5; j++) {
				if(i == j) continue;
				f[i][j] = INF;
			}
		}
		if(x == 2) {
			f[0][0] = 1, f[0][1] = 0;
		} else if (x == 0) {
			f[1][1] = 1, f[1][2] = 0;			
		} else if (x == 1) {
			f[2][2] = 1, f[2][3] = 0;
		} else if (x == 7) {
			f[3][3] = 1, f[3][4] = 0;
		} else if (x == 6) {
			f[3][3] = 1;
			f[4][4] = 1;
		} else if (x == -1){
			for (int i = 0; i < 5; i++)
				f[i][i] = INF;
		} else {
			for (int i = 0; i < 5; i++)
				f[i][i] = 0;
		}
	}
	
	void print() {
		for (int i = 0; i < 5; i++) {
			for (int j = 0; j < 5; j++) {
				if(f[i][j] == INF) printf("inf ");
				else printf("%d ", f[i][j]);
			}
			printf("\n");
		}
	}
};
node tr[maxn * 4];
node merge(node t1, node t2) {
	node ans;
	ans.init(-1);
//	ans.init(-1);
//	printf("ans\n");
//	ans.print();
//	printf("t1\n");
//	t1.print();
//	printf("t2\n");
//	t2.print();
	for (int i = 0; i < 5; i++) {
		for (int j = i; j < 5; j++) {
			for (int k = i; k <= j; k++) {
				ans.f[i][j] = min(ans.f[i][j], t1.f[i][k] + t2.f[k][j]); 
			}
		}
	}
//	printf("ans\n");
//	ans.print();
	return ans;
}
void build(int o, int l, int r) {
	if(l == r) {
		tr[o].init(a[l]);
		return;
	}
	int mid = (l + r) >> 1;
	build(ls, l, mid);
	build(rs, mid + 1, r);
	tr[o] = merge(tr[ls], tr[rs]);
}
void update(int o, int l, int r, int ql, int qr, int val) {
	if(l == r) {
		tr[o].init(val);
		return;
	}
	int mid = (l + r) >> 1;
	if(ql <= mid) update(ls, l, mid, ql, qr, val);
	if(qr > mid) update(rs, mid + 1, r, ql, qr, val);
	tr[o] = merge(tr[ls], tr[rs]);
}
node query(int o, int l, int r, int ql, int qr) {
	if(l >= ql && r <= qr) {
		return tr[o];
	}
	int mid = (l + r) >> 1;
	node ans;
	ans.init(-1);
	if(ql <= mid && qr > mid) ans = merge(query(ls, l, mid, ql, qr), query(rs, mid + 1, r, ql, qr));
	else if(ql <= mid) ans = query(ls, l, mid, ql, qr);
	else if(qr > mid) ans = query(rs, mid + 1, r, ql, qr);
	return ans; 
}
int main() {
	int n, m, l, r;
	scanf("%d%d", &n, &m);
	scanf("%s", s + 1);
	for (int i = 1; i <= n; i++) {
		a[i] = s[i] - '0';
	}
	build(1, 1, n);
	for (int i = 1; i <= m; i++) {
		scanf("%d%d", &l, &r);
		node ans = query(1, 1, n, l, r);
		if(ans.f[0][4] == INF) printf("-1\n");
		else printf("%d\n", ans.f[0][4]);
	}
} 

  

猜你喜欢

转载自www.cnblogs.com/pkgunboat/p/11488340.html