Problem link
Problem description
Given a tree with N nodes numbered from 1 to N, where each node has an integer weight, answer queries of the form: what is the k-th smallest weight among the nodes on the path from node u to node v?
Analysis
This is an introductory problem for persistent segment trees ("chairman trees") built over a tree.
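We can use an idea similar to prefix differences on a tree. Each node's version of the segment tree is built by inserting the node's own weight into its parent's version, so version tr[u] counts the weights on the path from the root to u. The weights on the path from u to v are then given by tr[u] + tr[v] - tr[LCA(u,v)] - tr[fa[LCA(u,v)]]: the LCA itself is counted exactly once, and its ancestors are cancelled out. Finally, we descend this combined tree to find the k-th smallest value.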
Code
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
// Array capacity: room for every node index plus a little slack.
const int N = 100010;

// Adjacency lists stored in flat arrays (a "linked forward star"):
//   h[x]  - index of the first edge leaving x (-1 when x has no edges);
//   e[i]  - target node of edge i;
//   ne[i] - index of the next edge after i in the same list;
//   pppp  - next unused edge index.
// Each undirected edge is stored in both directions, hence 2 * N slots.
int h[N];
int ne[N * 2];
int e[N * 2];
int pppp;

int n;      // number of tree nodes
int m;      // number of queries
int a[N];   // a[i]: weight of node i (1-indexed)

// d[u]: depth of u (the root gets depth 1; sentinel node 0 stays at depth 0).
int d[N];
// f[u][i]: the 2^i-th ancestor of u, or 0 when that ancestor does not exist.
int f[N][25];

bool st[N];  // NOTE(review): not read or written anywhere in the visible code

// Sorted, deduplicated copy of all weights, used for coordinate compression.
std::vector<int> nums;
// One node of the persistent ("chairman") segment tree over compressed
// weight indices.
//   l, r - indices of the children in tr (index 0 is never allocated and
//          acts as a null node with cnt == 0);
//   cnt  - how many inserted values fall inside this node's range.
struct Node{
    int l,r;
    int cnt;
}tr[N * 40];
// Capacity: build() allocates 2 * S - 1 nodes for S distinct values, and each
// insert() allocates one node per tree level, i.e. ceil(log2(S)) + 1 <= 18
// when S is about 1e5 (assuming n < N). The total is about 2e6 nodes, so
// N * 40 (~4e6 nodes, ~48 MB) is ample. The previous fixed 1e7 nodes reserved
// ~120 MB, which can exceed tight memory limits.

// root[u]: root of the version that belongs to tree node u (root[0] is the
// empty tree built in main); idx: last node index handed out.
int root[N],idx;
// Insert the directed edge x -> y at the head of x's adjacency list.
void add(int x,int y){
    int id = pppp++;   // claim the next free edge slot
    e[id] = y;
    ne[id] = h[x];     // link to the previous head of x's list
    h[x] = id;
}
// Map a raw weight x to its compressed index: the position of the first
// element of the sorted array nums that is not less than x.
int find(int x){
    vector<int>::iterator pos = lower_bound(nums.begin(), nums.end(), x);
    return int(pos - nums.begin());
}
// Allocate a complete segment tree over the index range [l, r] with all
// counts zero, and return the index of its root. main() uses it once to
// create the empty version root[0].
int build(int l,int r){
    int node = ++idx;
    if(l != r){
        int mid = (l + r) >> 1;
        tr[node].l = build(l, mid);
        tr[node].r = build(mid + 1, r);
    }
    return node;
}
// Persistent insertion: return the root of a new version that equals version
// p over the range [l, r] with value index x added once. Only the nodes on the
// path from the root to leaf x are copied; all other subtrees are shared with p.
int insert(int p,int l,int r,int x){
    int q = ++idx;
    tr[q] = tr[p];               // start from a copy of the old node
    if(l != r){
        int mid = (l + r) >> 1;
        if(x > mid) tr[q].r = insert(tr[p].r, mid + 1, r, x);
        else        tr[q].l = insert(tr[p].l, l, mid, x);
        // An internal node's count is the sum of its children's counts.
        tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
    }
    else{
        tr[q].cnt++;             // leaf: one more occurrence of value x
    }
    return q;
}
// Traverse the tree starting at u (whose parent is fa) and fill in, for every
// node x reached:
//   root[x] - persistent version = parent's version plus x's compressed
//             weight, so it counts the weights on the root-to-x path;
//   f[x][i] - the 2^i-th ancestor (binary lifting table, i = 0..18);
//   d[x]    - depth, d[parent] + 1.
// The traversal uses an explicit stack instead of recursion. On a path-shaped
// tree with ~1e5 nodes, one call frame per node can overflow the call stack.
// Every node is processed only after its parent, which is all that the
// root/f/d recurrences require.
void dfs(int u, int fa) {
    vector<int> stk(1, u);
    f[u][0] = fa;                  // a node's parent is recorded when it is pushed
    while (!stk.empty()) {
        int x = stk.back();
        stk.pop_back();
        int px = f[x][0];
        root[x] = insert(root[px], 0, nums.size() - 1, find(a[x]));
        d[x] = d[px] + 1;
        // The ancestors of x were processed earlier, so their tables are complete.
        for (int i = 1; i <= 18; i++)
            f[x][i] = f[f[x][i - 1]][i - 1];
        for (int i = h[x]; ~i; i = ne[i]) {
            int v = e[i];
            if (v == px) continue; // do not walk back up to the parent
            f[v][0] = x;
            stk.push_back(v);
        }
    }
}
// Lowest common ancestor of u and v, using the binary lifting table f[][]
// and the depths d[].
int LCA(int u, int v) {
    // Make u the deeper of the two nodes.
    if (d[v] > d[u]) swap(u, v);
    // Raise u to v's depth. A jump past the root lands on sentinel node 0,
    // whose depth 0 fails the test, so such jumps are skipped.
    for (int j = 18; j >= 0; j--) {
        int up = f[u][j];
        if (d[up] >= d[v]) u = up;
    }
    if (u == v) return u;
    // Raise both nodes while their ancestors differ; they end up as children
    // of the LCA.
    for (int j = 18; j >= 0; j--) {
        if (f[u][j] == f[v][j]) continue;
        u = f[u][j];
        v = f[v][j];
    }
    return f[u][0];
}
// Return the compressed index of the k-th smallest value (k is 1-based) in
// the multiset tr[x] + tr[y] - tr[z] - tr[p] over the range [l, r]. The four
// versions are descended in lockstep. main() passes the versions for u, v,
// LCA(u, v) and the LCA's parent, so the combination holds the weights on the
// u-v path.
int query(int x,int y,int z,int p,int l,int r,int k){
    while (l != r) {
        int mid = (l + r) >> 1;
        int leftCount = tr[tr[x].l].cnt + tr[tr[y].l].cnt
                      - tr[tr[z].l].cnt - tr[tr[p].l].cnt;
        if (k > leftCount) {
            // The answer is in the right half; skip the values in the left half.
            k -= leftCount;
            x = tr[x].r; y = tr[y].r; z = tr[z].r; p = tr[p].r;
            l = mid + 1;
        } else {
            x = tr[x].l; y = tr[y].l; z = tr[z].l; p = tr[p].l;
            r = mid;
        }
    }
    return r;
}
// Read the tree and the queries from stdin, then print each query's answer
// (the k-th smallest weight on the path between the two given nodes) on its
// own line.
int main(){
    memset(h, -1, sizeof h);            // -1 marks an empty adjacency list
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i++) {
        scanf("%d", a + i);
        nums.push_back(a[i]);
    }
    for (int edges = n - 1; edges > 0; edges--) {
        int x, y;
        scanf("%d%d", &x, &y);
        add(x, y);                      // store the undirected edge both ways
        add(y, x);
    }

    // Coordinate-compress the weights into indices 0 .. hi.
    sort(nums.begin(), nums.end());
    nums.erase(unique(nums.begin(), nums.end()), nums.end());
    const int hi = nums.size() - 1;

    root[0] = build(0, hi);             // empty version, used as the root's parent
    dfs(1, 0);

    int last = 0;
    while (m--) {
        int u, v, k;
        scanf("%d%d%d", &u, &v, &k);
        int lca = LCA(u, v);
        int pos = query(root[u], root[v], root[lca], root[f[lca][0]], 0, hi, k);
        last = nums[pos];
        printf("%d\n", last);
    }
    return 0;
}