题目链接:https://cn.vjudge.net/contest/290094#problem/B
题目大意:给你一棵树,每个节点都有一个权值,接下来q次询问,给你两个节点,询问这两个节点之间的路上第k大的权值是多少。
题解:假设两个节点为x,y,首先利用倍增lca求出这两个节点之间的lca,每个节点的主席树是以父亲节点为根建立的,两个节点之间的路的主席树为x+y-lca(x,y)-fa[lca(x,y)]。然后求区间第k大就行了。
代码:
#include <bits/stdc++.h>
using namespace std;
struct node
{
int data;
int l,r;
}tree[5000005];
int head[100005],tot;
vector<int>g[100005];
int s[100005],xx[100005],ans=0,fa[100005][40],depth[100005],lg[100005],n;
void init()
{
tree[0].l=tree[0].r=tot=tree[0].data=0;
}
void add(int l,int r,int &x,int y,int pos)
{
tree[++tot]=tree[y];tree[tot].data++;x=tot;
if(l==r) return ;
int mid=(l+r)>>1;
if(pos<=mid) add(l,mid,tree[x].l,tree[y].l,pos);
else add(mid+1,r,tree[x].r,tree[y].r,pos);
}
int query(int l,int r,int x,int y,int rt,int fart,int k)
{
if(l==r) return l;
int num=tree[tree[x].l].data+tree[tree[y].l].data-tree[tree[rt].l].data-tree[tree[fart].l].data;
int mid=(l+r)>>1;
if(num>=k) return query(l,mid,tree[x].l,tree[y].l,tree[rt].l,tree[fart].l,k);
else return query(mid+1,r,tree[x].r,tree[y].r,tree[rt].r,tree[fart].r,k-num);
}
int erfen(int num)
{
int l=1,r=ans;
while(l<=r)
{
int mid=(l+r)>>1;
if(s[mid]==num) return mid;
else if(s[mid]>num) r=mid-1;
else l=mid+1;
}
return 0;
}
void dfs(int now,int fath)
{
depth[now]=depth[fath]+1;
fa[now][0]=fath;
add(1,n,head[now],head[fath],erfen(xx[now]));
for(int i=1;(1<<i)<=depth[now];i++)
fa[now][i]=fa[fa[now][i-1]][i-1];
int len=g[now].size();
for(int i=0;i<len;i++)
{
int v=g[now][i];
if(v!=fath) dfs(v,now);
}
}
int lca(int a,int b)
{
if(depth[a]<depth[b]) swap(a,b);
while(depth[a]>depth[b]) a=fa[a][lg[depth[a]-depth[b]]-1];
if(a==b) return a;
for(int i=lg[depth[a]-1];i>=0;i--)
{
if(fa[a][i]!=fa[b][i])
{
a=fa[a][i];
b=fa[b][i];
}
}
return fa[a][0];
}
int main()
{
int m;
cin>>n>>m;
for(int i=1;i<=n;i++) scanf("%d",&s[i]),xx[i]=s[i];
sort(s+1,s+1+n);
s[0]=-1;
for(int i=1;i<=n;i++) if(s[i]!=s[i-1]) s[++ans]=s[i];
int u,v;
for(int i=1;i<=n-1;i++)
{
scanf("%d %d",&u,&v);
g[u].push_back(v);
g[v].push_back(u);
}
for(int i=1;i<=m+1;i++) lg[i]=lg[i-1]+(1<<lg[i-1]==i);
memset(fa,0,sizeof(fa));
memset(head,0,sizeof(head));
depth[0]=0;
dfs(1,0);
int k;
for(int i=1;i<=m;i++)
{
scanf("%d %d %d",&u,&v,&k);
int fath=lca(u,v);
printf("%d\n",s[query(1,n,head[u],head[v],head[fath],head[fa[fath][0]],k)]);
}
return 0;
}