Count on a tree: persistent segment tree ("chairman tree") + LCA

Portal

Problem description

Given a tree with N nodes, numbered from 1 to N, where each node has an integer weight, answer queries that ask for the k-th smallest weight among the nodes on the path from node u to node v.

analysis

An introductory problem for persistent segment trees (chairman trees) built on a tree

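We can solve this problem with an idea similar to prefix differences on a tree: each node's version of the persistent segment tree is created from its parent's version by inserting the node's own weight, so tr[u] describes the weights on the path from the root to u. The weights on the path from u to v are then given by tr[u] + tr[v] - tr[LCA(u,v)] - tr[fa[LCA(u,v)]], and the k-th smallest value is found by descending these four versions simultaneously.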

Code

#include <iostream>
#include <cstring>
#include <algorithm>
#include <cstdio>
#include <queue>
#include <vector>

using namespace std;

// Capacity of every per-vertex array (vertices are 1-based; index 0 is the sentinel parent).
const int N = 100010;

// Adjacency lists stored as linked edge slots.
int h[N];        // h[v]: first edge slot of vertex v, or -1 when the list is empty
int ne[N * 2];   // ne[i]: edge slot following slot i in the same list, or -1
int e[N * 2];    // e[i]: vertex that edge slot i points to
int pppp;        // next unused edge slot

int n;           // number of vertices read from the input
int m;           // number of queries read from the input

int a[N];        // a[v]: weight of vertex v

int d[N];        // d[v]: depth of v; d[0] stays 0, so the traversal root gets depth 1
int f[N][25];    // f[v][i]: ancestor of v that is 2^i edges up, 0 when there is none

bool st[N];      // not referenced by any code visible in this file

// Weights of all vertices; sorted and de-duplicated before the trees are built,
// so that a weight's position doubles as its compressed coordinate.
vector<int> nums;
// One node of the persistent segment tree.
struct Node {
    int l;    // index in tr of the left child
    int r;    // index in tr of the right child
    int cnt;  // how many inserted values fall inside this node's range
};

// Node pool shared by every tree version. Slot 0 is never handed out (allocation
// pre-increments the counter) and stays zero-initialised.
Node tr[10000005];
// root[v]: index in tr of the tree version created for vertex v;
// root[0] holds the initial tree that contains no values.
int root[N];
// Index of the most recently allocated slot in tr.
int idx;

// Inserts the directed edge x -> y at the front of x's adjacency list.
void add(int x,int y){
    const int slot = pppp++;  // claim the next free edge slot
    e[slot] = y;              // target vertex of the new edge
    ne[slot] = h[x];          // link to the previous head of the list
    h[x] = slot;              // the new edge becomes the head
}

// Returns the zero-based position in nums of the first element that is not
// less than x (nums.size() when every element is smaller).
int find(int x){
    const std::vector<int>::iterator pos = std::lower_bound(nums.begin(), nums.end(), x);
    return static_cast<int>(pos - nums.begin());
}

// Allocates the nodes of a segment tree covering [l, r] and returns the index
// of its top node. Counts are left untouched, so the tree holds no values.
int build(int l,int r){
    const int node = ++idx;
    if (l != r) {
        const int mid = (l + r) >> 1;
        tr[node].l = build(l, mid);       // left half first, then right half
        tr[node].r = build(mid + 1, r);
    }
    return node;
}

// Creates a new tree version equal to the version rooted at p plus one extra
// occurrence of position x, and returns the index of the new top node.
// Only the nodes on the path to x are copied; every other node is shared with p.
// [l, r] is the range covered by node p.
int insert(int p,int l,int r,int x){
    const int q = ++idx;
    tr[q] = tr[p];  // start as an exact copy, children included

    if (l == r) {
        tr[q].cnt++;
        return q;
    }

    const int mid = (l + r) >> 1;
    if (x > mid) {
        const int right = insert(tr[p].r, mid + 1, r, x);
        tr[q].r = right;
    } else {
        const int left = insert(tr[p].l, l, mid, x);
        tr[q].l = left;
    }
    tr[q].cnt = tr[tr[q].l].cnt + tr[tr[q].r].cnt;
    return q;
}

// Fills in everything stored per vertex when u is reached from its parent fa:
// the tree version (parent's version plus u's own weight), the depth and the
// binary-lifting ancestor table.
static void visit_vertex(int u, int fa) {
    root[u] = insert(root[fa], 0, static_cast<int>(nums.size()) - 1, find(a[u]));
    f[u][0] = fa;
    d[u] = d[fa] + 1;
    for (int i = 1; i <= 18; i++)
        f[u][i] = f[f[u][i - 1]][i - 1];
}

// Traverses the tree from u, whose parent is fa (0 for the overall root), and
// initialises root[], d[] and f[][] for every vertex reached.
//
// The traversal keeps its own stack instead of recursing: the recursion depth
// would equal the height of the tree, which on a path-shaped tree is the number
// of vertices and can exhaust the call stack. Vertices are still visited in
// pre-order following adjacency-list order, so tree versions are allocated in
// the same sequence as with a recursive traversal.
void dfs(int u, int fa) {
    std::vector<int> path;    // vertices on the current root-to-vertex path
    std::vector<int> cursor;  // for each of them, the next edge slot to examine

    visit_vertex(u, fa);
    path.push_back(u);
    cursor.push_back(h[u]);

    while (!path.empty()) {
        const int cur = path.back();
        const int i = cursor.back();
        if (i == -1) {  // adjacency list exhausted: leave this vertex
            path.pop_back();
            cursor.pop_back();
            continue;
        }
        cursor.back() = ne[i];  // advance before descending

        const int v = e[i];
        if (v == f[cur][0]) continue;  // do not walk back to the parent

        visit_vertex(v, cur);
        path.push_back(v);
        cursor.push_back(h[v]);
    }
}

// Returns the lowest common ancestor of u and v using the ancestor table f and
// the depth array d.
int LCA(int u, int v) {
    const int TOP = 18;  // highest level of f that is filled in

    // Make u the deeper of the two vertices.
    if (d[u] < d[v]) std::swap(u, v);

    // Raise u until it is level with v. f[.][i] is 0 past the root and d[0] is 0,
    // so a jump that would overshoot is never taken.
    int level = TOP;
    while (level >= 0) {
        if (d[f[u][level]] >= d[v]) u = f[u][level];
        --level;
    }
    if (u == v) return u;

    // Raise both together for as long as their ancestors differ.
    for (level = TOP; level >= 0; --level) {
        const int up_u = f[u][level];
        const int up_v = f[v][level];
        if (up_u == up_v) continue;
        u = up_u;
        v = up_v;
    }
    return f[u][0];
}

// Returns the position of the k-th smallest value in the multiset described by
// version x + version y - version z - version p, where all four nodes cover
// the range [l, r].
int query(int x,int y,int z,int p,int l,int r,int k){
    while (l != r) {
        // Number of values of the combined multiset that lie in the left half.
        const int left_cnt = tr[tr[x].l].cnt + tr[tr[y].l].cnt
                           - tr[tr[z].l].cnt - tr[tr[p].l].cnt;
        const int mid = (l + r) >> 1;
        if (k <= left_cnt) {
            x = tr[x].l;
            y = tr[y].l;
            z = tr[z].l;
            p = tr[p].l;
            r = mid;
        } else {
            x = tr[x].r;
            y = tr[y].r;
            z = tr[z].r;
            p = tr[p].r;
            l = mid + 1;
            k -= left_cnt;  // skip everything that sits in the left half
        }
    }
    return r;
}

int main(){
    
    
    memset(h,-1,sizeof h);
    scanf("%d%d",&n,&m);
    for(int i = 1;i <= n;i++) {
    
    
        scanf("%d",&a[i]);
        nums.push_back(a[i]);
    }
    for(int i = 1;i < n;i++){
    
    
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
    }
    sort(nums.begin(),nums.end());
    nums.erase(unique(nums.begin(),nums.end()),nums.end());
    root[0] = build(0,nums.size() - 1);
    dfs(1,0);
    int last = 0;
    while(m--){
    
    
        int l,r,k;
        scanf("%d%d%d",&l,&r,&k);
        int z = LCA(l,r);
        last = nums[query(root[l],root[r],root[z],root[f[z][0]],0,nums.size() - 1,k)];
        printf("%d\n",last);
    }
    return 0;
}

Guess you like

Origin blog.csdn.net/tlyzxc/article/details/112153740