HYSBZ 1036 树的统计Count(树链剖分+线段树)

1036: [ZJOI2008]树的统计Count

Time Limit: 10 Sec   Memory Limit: 162 MB
Submit: 19818   Solved: 8066
[ Submit][ Status][ Discuss]

Description

  一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成
一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 I
II. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

Input

  输入的第一行为一个整数n,表示节点的个数。接下来n – 1行,每行2个整数a和b,表示节点a和节点b之间有
一条边相连。接下来n行,每行一个整数,第i行的整数wi表示节点i的权值。接下来1行,为一个整数q,表示操作
的总数。接下来q行,每行一个操作,以“CHANGE u t”或者“QMAX u v”或者“QSUM u v”的形式给出。 
对于100%的数据,保证1<=n<=30000,0<=q<=200000;中途操作中保证每个节点的权值w在-30000到30000之间。

Output

  对于每个“QMAX”或者“QSUM”的操作,每行输出一个整数表示要求输出的结果。

Sample Input

4
1 2
2 3
4 1
4 2 1 3
12
QMAX 3 4
QMAX 3 3
QMAX 3 2
QMAX 2 3
QSUM 3 4
QSUM 2 1
CHANGE 1 5
QMAX 3 4
CHANGE 3 6
QMAX 3 4
QMAX 2 4
QSUM 3 4

Sample Output

4
1
2
2
10
6
5
6
5
16

HINT

Source


【思路】

朴素算法对每个操作都执行一遍深搜,不可取,需要一种能够记住点对点路径或者部分路径的方式。所以对整棵树进行轻重链剖分,把同一条链的节点映射到一个连续区间,再对其使用线段树维护。思想是:每个节点都属于某一条链,每条链都有一个唯一顶端,那么如果两个点具有同一个顶端,则位于同一条链上,那么其间的路径信息便可较快获取。


【代码】

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int MAXN = 30005, INF = 0x3f3f3f3f;

struct edge {
	int to, next;
}; 

struct segment {
	int left, right, mid;
	int sum, mx;
};

int n, q, cnt, tot;
int head[MAXN], p[MAXN], in[MAXN], id[MAXN],fa[MAXN], top[MAXN], sz[MAXN], max_son[MAXN], deep[MAXN];
edge e[MAXN << 1];
segment tree[MAXN << 2];

void addedge(int from, int to)
{
	++cnt;
	e[cnt].to = to;
	e[cnt].next = head[from];
	head[from] = cnt;
} 

void dfs_1(int u, int father, int depth)
{
	deep[u] = depth;
	fa[u] = father; max_son[u] = 0;
	sz[u] = 1;
	for (int i = head[u]; i != 0; i = e[i].next) {
		int v = e[i].to;
		if (v == fa[u]) continue;
		dfs_1(v, u, depth + 1);
		sz[u] += sz[v];
		if (sz[max_son[u]] < sz[v]) max_son[u] = v;
	}
}

void dfs_2(int u, int tp)
{
	in[u] = ++tot;
	id[tot] = u;
	top[u] = tp;
	if (max_son[u] != 0) dfs_2(max_son[u], tp);
	for (int i = head[u]; i != 0; i = e[i].next) {
		int v = e[i].to;
		if (v == fa[u] || v == max_son[u]) continue;
		dfs_2(v, v);
	}
}

void build(int left, int right, int root)
{
	tree[root].left = left;
	tree[root].right = right;
	tree[root].mid = (left + right) >> 1;
	if (left == right) {
		tree[root].mx = tree[root].sum = p[id[left]];
		return;
	}
	build(left, tree[root].mid, root << 1);
	build(tree[root].mid + 1, right, root << 1 | 1);
	tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum;
	tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx);
}

void modify(int index, int num, int root)
{
	if (tree[root].left == tree[root].right) {
		tree[root].mx = tree[root].sum = num;
		return;
	}
	if (index <= tree[root].mid) modify(index, num, root << 1);
	if (index >= tree[root].mid + 1) modify(index, num, root << 1 | 1);
	tree[root].mx = max(tree[root << 1].mx, tree[root << 1 | 1].mx);
	tree[root].sum = tree[root << 1].sum + tree[root << 1 | 1].sum;
}

int sum_query(int l, int r, int root)
{
	if (l <= tree[root].left && tree[root].right <= r) return tree[root].sum;
	int ans = 0;
	if (l <= tree[root].mid) ans += sum_query(l, r, root << 1);
	if (r >= tree[root].mid + 1) ans += sum_query(l, r, root << 1 | 1);
	return ans;
}

int max_query(int l, int r, int root)
{
	if (l <= tree[root].left && tree[root].right <= r) return tree[root].mx;
	int ans = -INF;
	if (l <= tree[root].mid) ans = max(ans, max_query(l, r, root << 1));
	if (r >= tree[root].mid + 1) ans = max(ans, max_query(l, r, root << 1 | 1));
	return ans;
}

int main()
{
	cnt = 0;
	memset(head, 0, sizeof(head));
	scanf("%d", &n);
	for (int i = 1; i <= n - 1; i++) {
		int a, b; scanf("%d %d", &a, &b);
		addedge(a, b);
		addedge(b, a);
	}
	for (int i = 1; i <= n; i++) scanf("%d", &p[i]);
	dfs_1(1, 0, 1);
	tot = 0;
	dfs_2(1, 1);
	build(1, n, 1);
	scanf("%d", &q);
	while (q--) {
		char mes[7];
		int u, v; scanf("%s %d %d", mes, &u, &v);
		if (mes[0] == 'C') modify(in[u], v, 1);
		if (mes[1] == 'M') {
			int ans = -INF;
			while (top[u] != top[v]) {
				if (deep[top[u]] < deep[top[v]]) swap(u, v);
				ans = max(ans, max_query(in[top[u]], in[u], 1));
				u = fa[top[u]];
			} 
			if (deep[u] > deep[v]) swap(u, v);
			ans = max(ans, max_query(in[u], in[v], 1));
			printf("%d\n", ans);
		}
		if (mes[1] == 'S') {
			int ans = 0;
			while (top[u] != top[v]) {
				if (deep[top[u]] < deep[top[v]]) swap(u, v);
				ans += sum_query(in[top[u]], in[u], 1);
				u = fa[top[u]];
			}
			if (deep[u] > deep[v]) swap(u, v);
			ans += sum_query(in[u], in[v], 1);
			printf("%d\n", ans);
		}
	}
	return 0;
}


猜你喜欢

转载自blog.csdn.net/shili_xu/article/details/78885673
今日推荐