SPOJ - COT Count on a tree(树上第k大模板题)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Cymbals/article/details/82998385

You are given a tree with N nodes. The tree nodes are numbered from 1 to N. Each node has an integer weight.

We will ask you to perform the following operation:

u v k : ask for the kth minimum weight on the path from node u to node v

Input
In the first line there are two integers N and M (N, M <= 100000).
In the second line there are N integers. The ith integer denotes the weight of the ith node
In the next N-1 lines, each line contains two integers u v, which describes an edge (u, v)
In the next M lines, each line contains three integers u v k, which means an operation asking for the kth minimum weight on the path from node u to node v
Output
For each operation, print its result.

给一棵无根树,每个节点有一个权值;有m个询问,每次询问u到v路径上第k小的节点权值。

在接触树上第k大之前先做了HDU-4757,这题是树上的可持久化01字典树,所以一上来就很有思路。

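树上u到v的路径可以借助lca拆开。和上面那题一样,每个节点以父亲的版本为last版本新建一棵可持久化线段树,所以root[x]记录的是从根到x这条链上所有权值的出现次数。查询时把u、v两棵树相加:lca到根的那段链被算了两次,所以再减去lca的树和lca父亲的树。这样lca本身恰好只保留一次,就差分出了路径u→v上的权值线段树,即 $cnt_{path}(x) = cnt_{u}(x) + cnt_{v}(x) - cnt_{lca}(x) - cnt_{fa(lca)}(x)$。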

ac代码:

#include<bits/stdc++.h>
using namespace std;

// Capacity for node-indexed arrays (the statement bounds N <= 100000; a few spare slots).
const int maxn = 100005;

int n, m;                 // number of tree nodes / number of queries
int in[maxn];             // raw weights; init() sorts them so in[rank] gives the weight of a rank
int num[maxn];            // weight of node i; init() overwrites it with the node's compressed rank
vector<int> G[maxn];      // undirected adjacency lists, nodes numbered 1..n
int p[maxn][20];          // p[x][i]: the 2^i-th ancestor of x (0 above the root)
int dep[maxn];            // depth of each node; dep[0] == 0, so the root (node 1) has depth 1

// One node of the persistent segment tree built over compressed weight ranks.
struct Node {
	int val;              // how many inserted ranks fall inside this node's range
	Node *lc;             // left child
	Node *rc;             // right child
};

Node *root[maxn];         // root[x]: tree version for the chain from the root down to x
Node pool[maxn * 20];     // static node storage; slots are handed out with ++tail
Node *tail = pool;        // last slot handed out (pool[0] itself is never used)
Node *null;               // empty-tree sentinel, set up by init()


// Produces a new version of a persistent segment tree over the rank range [l, r]:
// the result equals `pre` with one more occurrence counted at leaf `pos`.
// Only the nodes on the path down to `pos` are taken fresh from the pool
// (via ++tail); every other subtree is shared with `pre`.
// Returns `pre` itself when `pos` is outside [l, r].
Node* update(Node *pre, int l, int r, int pos) {
	const bool inside = (l <= pos && pos <= r);
	if(!inside) {
		return pre;
	}
	Node *fresh = ++tail;
	if(l == r) {
		fresh->val = pre->val + 1;
		return fresh;
	}
	const int mid = (l + r) >> 1;
	// Exactly one side contains pos; the other side is reused from the old version.
	if(pos <= mid) {
		fresh->lc = update(pre->lc, l, mid, pos);
		fresh->rc = pre->rc;
	} else {
		fresh->lc = pre->lc;
		fresh->rc = update(pre->rc, mid + 1, r, pos);
	}
	fresh->val = fresh->lc->val + fresh->rc->val;
	return fresh;
}


// Walks the tree from u (whose parent is fa) and fills, for every node x reached:
//   root[x] - persistent segment tree version = parent's version plus num[x],
//   p[x][0] - direct parent,
//   dep[x]  - depth, i.e. dep[parent] + 1.
// Uses an explicit stack instead of recursion. A path-shaped tree with ~1e5 nodes
// would otherwise recurse ~1e5 frames deep and can overflow small call stacks.
// A node is only pushed after its parent has been processed, so root[parent]
// is always ready when the child is handled.
void dfs(int u, int fa) {
	std::vector<std::pair<int, int>> pending;   // (node, parent) pairs still to visit
	pending.emplace_back(u, fa);
	while(!pending.empty()) {
		const int x = pending.back().first;
		const int parent = pending.back().second;
		pending.pop_back();
		root[x] = update(root[parent], 1, n, num[x]);
		p[x][0] = parent;
		dep[x] = dep[parent] + 1;
		for(const int y : G[x]) {
			if(y != parent) {   // skip the edge leading back up
				pending.emplace_back(y, x);
			}
		}
	}
}

int lca(int u, int v) {
	if(dep[u] > dep[v]) swap(u, v);
	for(int i = 0; i < 20; i++) {
		if((dep[v] - dep[u]) >> i & 1) {
			v = p[v][i];
		}
	}
	if(v == u) return u;
	for(int i = 20 - 1; i >= 0; i--) {
		if(p[u][i] != p[v][i]) {
			u = p[u][i];
			v = p[v][i];
		}
	}
	return p[u][0];
}

// Returns the k-th smallest rank on a tree path, using four versions of the
// persistent segment tree over the rank range [l, r]. The per-range count on
// the path is
//     count(u) + count(v) - count(top) - count(topfa),
// i.e. both endpoint chains minus the LCA chain (top) and its parent's chain (topfa).
// Descends a single root-to-leaf path and returns the leaf rank reached.
int query(Node *topfa, Node *top, Node *u, Node *v, int l, int r, int k) {
	while(l != r) {
		const int mid = (l + r) >> 1;
		const int leftCount = u->lc->val + v->lc->val - top->lc->val - topfa->lc->val;
		const bool goLeft = (k <= leftCount);
		if(goLeft) {
			r = mid;
		} else {
			l = mid + 1;
			k -= leftCount;   // skip every rank counted in the left half
		}
		topfa = goLeft ? topfa->lc : topfa->rc;
		top = goLeft ? top->lc : top->rc;
		u = goLeft ? u->lc : u->rc;
		v = goLeft ? v->lc : v->rc;
	}
	return l;
}


// Prepares everything the queries need:
//  * the empty-tree sentinel `null` and empty roots root[0..n],
//  * coordinate compression: num[i] becomes a rank, and in[rank] is that rank's weight,
//  * one persistent segment tree version per node (dfs from node 1, parent 0),
//  * the binary-lifting ancestor table p[][].
void init() {
	null = ++tail;
	null->val = 0;
	// The sentinel's children point back to itself, so walking down an empty
	// version in query() never dereferences a null pointer.
	null->lc = null->rc = null;
	for(int i = 0; i <= n; i++) {
		root[i] = null;
	}

	// Coordinate compression. std::unique leaves the elements after its returned
	// end with unspecified values, so the full range in[1..n] is no longer sorted.
	// lower_bound must therefore search only the deduplicated prefix in[1..distinct].
	// The resulting ranks lie in [1, distinct], a subset of [1, n], so the tree
	// range 1..n used by dfs()/query() remains valid.
	sort(in + 1, in + n + 1);
	const int distinct = static_cast<int>(unique(in + 1, in + n + 1) - (in + 1));
	for(int i = 1; i <= n; i++) {
		num[i] = static_cast<int>(lower_bound(in + 1, in + distinct + 1, num[i]) - in);
	}

	dfs(1, 0);
	// p[j][i] = the 2^(i-1)-th ancestor of the 2^(i-1)-th ancestor of j.
	for(int i = 1; i < 20; i++) {
		for(int j = 1; j <= n; j++) {
			p[j][i] = p[p[j][i - 1]][i - 1];
		}
	}
}


// Reads N weights, N-1 edges and M queries "u v k", and prints for each query
// the k-th smallest weight on the path from u to v.
int main() {
	scanf("%d%d", &n, &m);
	for(int node = 1; node <= n; ++node) {
		scanf("%d", &in[node]);
		num[node] = in[node];   // raw weight for now; init() turns it into a rank
		G[node].clear();
	}

	int a, b, k;
	for(int edge = 1; edge < n; ++edge) {
		scanf("%d%d", &a, &b);
		G[a].push_back(b);
		G[b].push_back(a);
	}

	init();

	for(int remaining = m; remaining > 0; --remaining) {
		scanf("%d%d%d", &a, &b, &k);
		const int anc = lca(a, b);
		// Parent of the LCA; for the root this is 0, whose version is the empty tree.
		const int ancParent = p[anc][0];
		const int pos = query(root[ancParent], root[anc], root[a], root[b], 1, n, k);
		printf("%d\n", in[pos]);
	}
	return 0;
}

猜你喜欢

转载自blog.csdn.net/Cymbals/article/details/82998385
今日推荐