【树链剖分模板】 BZOJ1036 树的统计

在 dfs1 中先求出每个结点的子树大小、深度、父结点和重儿子;dfs2 再据此划分重链并分配 dfs 序。

BZOJ传送门:点击打开链接

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<cctype>
using namespace std;
// Sentinel used (negated) as the identity element of the max queries.
const int inf = 0x7f7f7f7f;
// Capacity of every per-node array; nodes are numbered starting from 1.
const int maxn = 30007;
// Segment-tree child indices of the node held in the local variable `p`.
// Fully parenthesized so the macros stay correct inside larger expressions
// (e.g. `ls + 1`), where the bare `p<<1` / `p<<1|1` forms would re-associate.
#define ls (p << 1)
#define rs (p << 1 | 1)
// One directed arc of the adjacency lists built by insert().
struct edge {
	int v;   // node this arc points to
	int nxt; // index in `e` of the next arc leaving the same node; 0 ends the list
};
// Arc pool. insert() stores every undirected edge as two arcs, hence 2 * maxn
// slots; slot 0 is never filled because `eid` is incremented before use.
edge e[maxn << 1];
// ---- Tree stored as adjacency lists ----
int head[maxn];       // head[u]: index in `e` of the first arc leaving u (0 = none)
int eid = 0;          // index of the most recently allocated arc in `e`
int val[maxn];        // val[u]: current weight of node u

// ---- Heavy-light decomposition data (filled by dfs1 / dfs2) ----
int siz[maxn];        // subtree size; also serves as dfs1's "visited" marker
int dep[maxn];        // depth of the node (the dfs1 start node keeps depth 0)
int fa[maxn];         // parent node
int son[maxn];        // heavy child (child with the largest subtree), 0 if none
int top[maxn];        // topmost node of the heavy chain containing the node
int l[maxn];          // position of the node in the dfs2 visiting order (1-based)
int tot;              // number of positions handed out so far by dfs2

// ---- Segment tree over dfs-order positions ----
int sum[maxn << 2];   // sum of the weights in each tree node's range
int maxx[maxn << 2];  // maximum weight in each tree node's range

// ---- Problem input sizes ----
int n;                // number of tree nodes
int q;                // number of operations
// Adds the undirected edge (u, v) to the adjacency lists by pushing two arcs,
// u -> v and then v -> u, at the front of the respective lists.
void insert(int u, int v) {
	// Arc u -> v.
	++eid;
	e[eid].v = v;
	e[eid].nxt = head[u];
	head[u] = eid;

	// Arc v -> u.
	++eid;
	e[eid].v = u;
	e[eid].nxt = head[v];
	head[v] = eid;
}


// Buffer for the operation keyword of the current query; filled in main(),
// which dispatches on its second character.
char cmd[10];

// First pass of the decomposition: starting from u, computes subtree size,
// depth, parent and heavy child for every node reachable from u.
// A node counts as unvisited while its siz[] entry is still 0.
void dfs1(int u) {
	siz[u] = 1;
	for (int arc = head[u]; arc != 0; arc = e[arc].nxt) {
		const int child = e[arc].v;
		if (siz[child] != 0)
			continue; // already visited, i.e. the node we came from

		fa[child] = u;
		dep[child] = dep[u] + 1;
		dfs1(child);

		siz[u] += siz[child];
		// son[u] starts at 0 and siz[0] is 0, so the first child always wins;
		// ties keep the earlier child.
		if (siz[son[u]] < siz[child])
			son[u] = child;
	}
}
// Second pass of the decomposition: gives u the next dfs-order position and
// records t as the top of its chain. The heavy child is visited first so a
// heavy chain occupies consecutive positions; every other child starts a new
// chain with itself as top.
void dfs2(int u, int t) {
	top[u] = t;
	l[u] = ++tot;

	const int heavy = son[u];
	if (heavy != 0)
		dfs2(heavy, t);

	for (int arc = head[u]; arc != 0; arc = e[arc].nxt) {
		const int next = e[arc].v;
		// Skip the parent (not deeper than u) and the heavy child handled above.
		if (next == heavy || dep[next] <= dep[u])
			continue;
		dfs2(next, next);
	}
}

// Recomputes the sum and maximum stored at segment-tree node p from its two
// children (indices 2p and 2p+1).
void pushup(int p) {
	const int left = p << 1;
	const int right = left | 1;
	maxx[p] = maxx[left] < maxx[right] ? maxx[right] : maxx[left];
	sum[p] = sum[left] + sum[right];
}
// Point assignment in the segment tree: sets position x to value c.
//   p    - current tree node, covering positions [l, r]
//   x    - position to overwrite
//   c    - new value
// Ancestors on the way back up are refreshed through pushup().
void modify(int p, int l, int r, int x, int c) {
	if (l == r) {
		// Leaf: a single value is both the sum and the maximum.
		sum[p] = c;
		maxx[p] = c;
		return;
	}

	const int mid = (l + r) >> 1;
	const int left = p << 1;
	const int right = left | 1;
	if (x > mid)
		modify(right, mid + 1, r, x, c);
	else
		modify(left, l, mid, x, c);

	pushup(p);
}
// Returns the sum of the values at positions [x, y].
//   p - current tree node, covering positions [l, r]
int querysum(int p, int l, int r, int x, int y) {
	// Node range fully inside the query range: its stored sum is the answer.
	if (l >= x && y >= r)
		return sum[p];

	const int mid = (l + r) >> 1;
	int total = 0;
	if (mid >= x)
		total += querysum(p << 1, l, mid, x, y);
	if (mid < y)
		total += querysum(p << 1 | 1, mid + 1, r, x, y);
	return total;
}
// Returns the maximum of the values at positions [x, y].
//   p - current tree node, covering positions [l, r]
int querymax(int p, int l, int r, int x, int y) {
	// Node range fully inside the query range: its stored maximum is the answer.
	if (l >= x && y >= r)
		return maxx[p];

	const int mid = (l + r) >> 1;
	// Original author's note: node weights lie in -30000..30000, so -inf is a
	// safe starting value.
	int best = -inf;
	if (mid >= x) {
		const int leftBest = querymax(p << 1, l, mid, x, y);
		if (best < leftBest)
			best = leftBest;
	}
	if (mid < y) {
		const int rightBest = querymax(p << 1 | 1, mid + 1, r, x, y);
		if (best < rightBest)
			best = rightBest;
	}
	return best;
}
// Returns the sum of the node weights on the tree path between x and y.
// While the endpoints are on different chains, the endpoint whose chain top
// is deeper is lifted to the parent of that top, adding the chain segment it
// leaves behind; the last step covers the stretch inside the common chain.
int getsum(int x, int y) {
	int total = 0;
	for (; top[x] != top[y]; x = fa[top[x]]) {
		// Keep x as the endpoint whose chain top is deeper.
		if (dep[top[y]] > dep[top[x]])
			std::swap(x, y);
		total += querysum(1, 1, n, l[top[x]], l[x]);
	}

	// Same chain now: query between the two positions, smaller one first.
	const int from = std::min(l[x], l[y]);
	const int to = std::max(l[x], l[y]);
	return total + querysum(1, 1, n, from, to);
}
// Returns the largest node weight on the tree path between x and y.
// Uses the same chain-by-chain walk as getsum(), combining segments with max.
int getmax(int x, int y) {
	int best = -inf;
	for (; top[x] != top[y]; x = fa[top[x]]) {
		// Keep x as the endpoint whose chain top is deeper.
		if (dep[top[y]] > dep[top[x]])
			std::swap(x, y);
		const int chainBest = querymax(1, 1, n, l[top[x]], l[x]);
		if (best < chainBest)
			best = chainBest;
	}

	// Same chain now: query between the two positions, smaller one first.
	const int from = std::min(l[x], l[y]);
	const int to = std::max(l[x], l[y]);
	const int lastBest = querymax(1, 1, n, from, to);
	return best < lastBest ? lastBest : best;
}
// Scratch variable used by main(): first endpoint of each edge and first
// operand of each query.
int u;
// Reads the next decimal integer from stdin.
// Skips every character that is not a digit; if a '-' is seen while skipping,
// the result is negated. Returns 0 when end of input is reached before any
// digit is found.
inline int read() {
	// getchar() returns int so that EOF can be told apart from real characters;
	// storing it in a char made the skip loop spin forever at end of input.
	int c = getchar();
	int sign = 1;
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-')
			sign = -1;
		c = getchar();
	}

	int value = 0;
	while (c >= '0' && c <= '9') {
		value = value * 10 + (c - '0');
		c = getchar();
	}
	return value * sign;
}
// Program entry point.
// Input order: n, then n-1 edges, then the n node weights, then q, then q
// operations of the form "<keyword> a b". The operation is selected by the
// second character of the keyword:
//   'H'           - set the weight of node a to b
//   'M'           - print the maximum weight on the path a..b
//   anything else - print the sum of the weights on the path a..b
// (presumably the keywords are CHANGE / QMAX / QSUM - confirm against the
// problem statement).
int main() {
	n = read();
	for (int i = 1; i < n; i++) {
		u = read();
		const int v = read();
		insert(u, v);
	}
	for (int i = 1; i <= n; i++)
		val[i] = read();

	// Decompose the tree rooted at node 1, then load the weights into the
	// segment tree at their dfs-order positions.
	dfs1(1);
	dfs2(1, 1);
	for (int i = 1; i <= n; i++)
		modify(1, 1, n, l[i], val[i]);

	q = read();
	for (int i = 1; i <= q; i++) {
		// The width limit keeps the keyword inside cmd[10]; stop if the input
		// ends early instead of dispatching on a stale keyword.
		if (scanf("%9s", cmd) != 1)
			break;
		u = read();
		const int v = read();
		switch (cmd[1]) {
		case 'H': val[u] = v; modify(1, 1, n, l[u], v); break;
		case 'M': printf("%d\n", getmax(u, v)); break;
		default: printf("%d\n", getsum(u, v));
		}
	}
	return 0;
}


 
 

猜你喜欢

转载自blog.csdn.net/qq_35755187/article/details/80049229