计蒜客 2019南昌邀请网络赛J Distance on the tree(主席树)题解

题意：给出一棵 n 个节点的树以及每条边的权值。共有 m 个询问，每个询问给出 u、v、k，要求输出 u 到 v 的（唯一）路径上边权 <= k 的边有几条。

思路：网络赛时还没学过主席树，现在补上。先在树上建主席树：把每条边的权值下放给它的子节点，每个节点的版本都在父节点版本的基础上插入自己的边权。设 cnt(x) 为根到 x 的路径上边权 <= k 的边数，则答案为 cnt(u) + cnt(v) - 2·cnt(lca(u, v))。专题里那道统计点权的题基本就是这道的原题。一次 AC，算是靠做模板题提升一下自信。

代码:

#include<cmath>
#include<set>
#include<map>
#include<queue>
#include<cstdio>
#include<vector>
#include<cstring>
#include <iostream>
#include<algorithm>
using namespace std;
// Integer aliases used throughout the solution.
typedef long long ll;
typedef unsigned long long ull;

// Capacities and miscellaneous constants.
const int maxn = 100000 + 10;   // array capacity for vertices / queries (with slack)
const int M = maxn * 30;        // NOTE(review): not referenced in the visible code
const ull seed = 131;           // NOTE(review): not referenced in the visible code
const int INF = 0x3f3f3f3f;     // NOTE(review): not referenced in the visible code
const int MOD = 1000000007;     // NOTE(review): not referenced in the visible code

// n: number of vertices, m: number of queries.
int n;
int m;
// root[x]: index (into the node pool) of the segment-tree version belonging to vertex x.
// tot: number of pool nodes handed out so far.
int root[maxn];
int tot;
struct Edge{
    int v, next;
    ll w;
}edge[maxn << 1];
int head[maxn], tol;
void addEdge(int u, int v, ll w){
    edge[tol].v = v;
    edge[tol].w = w;
    edge[tol].next = head[u];
    head[u] = tol++;
}
// Pool of persistent segment-tree nodes over value ranks.
// lson/rson index child nodes in T (index 0 is an all-zero node that is never
// overwritten, since allocation starts at ++tot); sum counts inserted ranks in range.
struct node{
    int lson;
    int rson;
    int sum;
}T[maxn * 40];

// Reset all global state: empty adjacency lists, empty node pool, all versions empty.
void init(){
    tot = 0;
    tol = 0;
    memset(head, -1, sizeof head);   // -1 terminates every adjacency list
    memset(root, 0, sizeof root);
    memset(T, 0, sizeof T);
}
// Sorted, de-duplicated list of every value used as a segment-tree key
// (edge weights, query thresholds and the 0 given to the root).
// Stored as 64-bit because the values are read with %lld. An int vector would
// silently truncate anything above INT_MAX and corrupt the ranks.
std::vector<long long> vv;

// Return the 1-based rank of x within vv (the position of the first element >= x).
// Callers pass values that were inserted into vv before it was sorted and
// de-duplicated, so x is expected to be present.
int getid(long long x){
    return static_cast<int>(std::lower_bound(vv.begin(), vv.end(), x) - vv.begin()) + 1;
}
// Persistent point update on value range [l, r].
// Builds a new version (its root is written into `now`) that copies version `pre`
// along the root-to-leaf path of `pos` and adds v to the count of every copied node.
void update(int l, int r, int &now, int pre, int v, int pos){
    int cur = ++tot;          // fresh node for this level of the new version
    T[cur] = T[pre];
    T[cur].sum += v;
    now = cur;
    if(l < r){
        int mid = (l + r) >> 1;
        if(pos > mid)
            update(mid + 1, r, T[cur].rson, T[pre].rson, v, pos);
        else
            update(l, mid, T[cur].lson, T[pre].lson, v, pos);
    }
}
// DFS from vertex `now` with parent `pre`. Version root[now] is root[pre] with the
// rank of w inserted, where w is the weight of the edge pre -> now (weights live on children).
void build(int now, int pre, ll w){
    update(1, vv.size(), root[now], root[pre], 1, getid(w));
    for(int e = head[now]; ~e; e = edge[e].next){   // ~e == 0 exactly when e == -1
        if(edge[e].v != pre)
            build(edge[e].v, now, edge[e].w);
    }
}
// Count inserted ranks <= k on the path between two vertices, over value range [l, r].
// now/pre are the version roots of the two endpoints and lca the version root of their LCA.
// Each node's contribution is sum(now) + sum(pre) - 2 * sum(lca).
// Iterative descent: whenever the walk goes right, the whole left child is <= k and is
// added to the running total.
int query(int l, int r, int now, int pre, int lca, int k){
    int acc = 0;
    for(;;){
        if(r <= k)
            return acc + T[now].sum + T[pre].sum - 2 * T[lca].sum;
        if(l == r)                // single rank that lies above k
            return acc;
        const int mid = (l + r) >> 1;
        if(k <= mid){
            now = T[now].lson;
            pre = T[pre].lson;
            lca = T[lca].lson;
            r = mid;
        }else{
            acc += T[T[now].lson].sum + T[T[pre].lson].sum - 2 * T[T[lca].lson].sum;
            now = T[now].rson;
            pre = T[pre].rson;
            lca = T[lca].rson;
            l = mid + 1;
        }
    }
}

// ---- LCA by binary lifting ----
// fa[u][i]: the 2^i-th ancestor of u (fa[u][0] is the parent passed to lca_dfs).
// dep[u]: depth of u, counted from the d given for the start vertex.
int fa[maxn][20];
int dep[maxn];

// DFS from u (parent `pre`, depth d) recording direct parents and depths.
void lca_dfs(int u, int pre, int d){
    fa[u][0] = pre;
    dep[u] = d;
    for(int i = head[u]; i != -1; i = edge[i].next){
        const int v = edge[i].v;
        if(v == pre) continue;
        lca_dfs(v, u, d + 1);
    }
}
// Fill the binary-lifting table: the 2^i-th ancestor is the 2^(i-1)-th ancestor
// of the 2^(i-1)-th ancestor. Levels are computed while 2^i <= n.
void lca_update(){
    for(int i = 1; (1 << i) <= n; ++i)
        for(int u = 1; u <= n; ++u){
            const int half = fa[u][i - 1];
            fa[u][i] = fa[half][i - 1];
        }
}
int lca_query(int u, int v){
    if(dep[u] < dep[v]) swap(u, v);
    int d = dep[u] - dep[v];
    for(int i = 0; (1 << i) <= d; i++){
        if(d & (1 << i)){
            u = fa[u][i];
        }
    }
    if(u != v){
        for(int i = (int)log2(n); i >= 0; i--){
            if(fa[u][i] != fa[v][i]){
                u = fa[u][i];
                v = fa[v][i];
            }
        }
        u = fa[u][0];
    }
    return u;
}
// Stored queries: endpoints u1[q], v1[q] and threshold k1[q] (1-based).
int u1[maxn], v1[maxn];
ll k1[maxn];

// Read the tree and all queries, compress every value, build both structures, then
// answer each query as cnt(u) + cnt(v) - 2 * cnt(lca) over ranks <= rank(k).
int main(){
    init();
    vv.clear();
    scanf("%d%d", &n, &m);
    // 0 is the weight build(1, 0, 0) gives the root. It appears in every version, so it
    // cancels out in cnt(u) + cnt(v) - 2 * cnt(lca).
    vv.push_back(0);
    for(int e = 1; e < n; ++e){
        int a, b;
        ll w;
        scanf("%d%d%lld", &a, &b, &w);
        addEdge(a, b, w);
        addEdge(b, a, w);
        vv.push_back(w);
    }
    for(int q = 1; q <= m; ++q){
        scanf("%d%d%lld", u1 + q, v1 + q, k1 + q);
        vv.push_back(k1[q]);   // thresholds are compressed too, so getid finds them exactly
    }
    sort(vv.begin(), vv.end());
    vv.erase(unique(vv.begin(), vv.end()), vv.end());
    lca_dfs(1, 0, 1);
    lca_update();
    build(1, 0, 0);
    for(int q = 1; q <= m; ++q){
        const int anc = lca_query(u1[q], v1[q]);
        printf("%d\n", query(1, vv.size(), root[u1[q]], root[v1[q]], root[anc], getid(k1[q])));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/KirinSB/p/10897172.html