2019 年百度之星·程序设计大赛 - 初赛四 05

解法:求树上两点路径长度,其实就是 d e p [ u ] + d e p [ v ] 2 d e p [ l c a ] dep[u] +dep[v]-2*dep[lca] ,这里我们采用重链剖分的方式求lca,对于每次修改,都会改变树的结构,我们可以用一颗线段树维护子树的top(重链剖分的链头),这个很简单,懂树剖都会写求lca,怎么求深度?如果当前节点所在的子树没有被修改,那么就是原 d e p dep ,如果被修改了,我们通过线段树找到节点所在的最大的被修改的子树,然后用可持久化线段树查 u u 在该子树中排第几,就能得到该节点的深度,然后水题啦
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
int rt[maxn], ls[maxn * 20], rs[maxn * 20], sum[maxn * 20], cnt, n;
int Set[maxn * 4], L[maxn], R[maxn], id, son[maxn], sz[maxn], Top[maxn];
int f[maxn], dep[maxn];
vector<int> G[maxn];
#define mid (l + r) / 2
void up(int &o, int pre, int l, int r, int k) {
    o = ++cnt;
    ls[o] = ls[pre];
    rs[o] = rs[pre];
    sum[o] = sum[pre] + 1;
    if (l == r)
        return;
    if (k <= mid)
        up(ls[o], ls[pre], l, mid, k);
    else
        up(rs[o], rs[pre], mid + 1, r, k);
}
int qu(int o, int pre, int l, int r, int ql, int qr) {
    if (l >= ql && r <= qr)
        return sum[o] - sum[pre];
    int res = 0;
    if (ql <= mid)
        res += qu(ls[o], ls[pre], l, mid, ql, qr);
    if (qr > mid)
        res += qu(rs[o], rs[pre], mid + 1, r, ql, qr);
    return res;
}
void update(int o, int l, int r, int ql, int qr, int v) {
    if (l >= ql && r <= qr) {
        Set[o] = v;
        return;
    }
    if (ql <= mid)
        update(o * 2, l, mid, ql, qr, v);
    if (qr > mid)
        update(o * 2 + 1, mid + 1, r, ql, qr, v);
}
int query(int o, int l, int r, int k) {
    if (Set[o] || l == r)
        return  Set[o];
    if (k <= mid)
        return query(o * 2, l, mid, k);
    return query(o * 2 + 1, mid + 1, r, k);
}
void dfs1(int u, int fa) {
    L[u] = ++id;
    up(rt[id], rt[id - 1], 1, n, u);
    sz[u] = 1;
    son[u] = 0;
    dep[u] = dep[fa] + 1;
    f[u] = fa;
    for (auto v : G[u])
    if (v != fa) {
        dfs1(v, u);
        sz[u] += sz[v];
        if (sz[son[u]] < sz[v])
            son[u] = v;
    }
    R[u] = id;
}
void dfs2(int u, int top) {
    Top[u] = top;
    if (son[u])
        dfs2(son[u], top);
    for (auto v : G[u])
        if (v != f[u] && v != son[u])
            dfs2(v, v);
}
int calc(int u) {
    int tmp = query(1, 1, n, L[u]);
    if (!tmp)
        return dep[u];
    return dep[f[tmp]] + qu(rt[R[tmp]], rt[L[tmp] - 1], 1, n, u, n);
}
int LCA(int u, int v) {
    while (Top[u] != Top[v]) {
        if (dep[Top[u]] < dep[Top[v]])
            swap(u, v);
        u = f[Top[u]];
    }
    if (dep[u] > dep[v])
        swap(u, v);
    return u;
}
int gao(int u, int v) {
    int lca = LCA(u, v);
    int o = query(1, 1, n, L[lca]);
    if (o != 0) {
        int ans = qu(rt[R[o]], rt[L[o] - 1], 1, n, u, n) - qu(rt[R[o]], rt[L[o] - 1], 1, n, v, n);
        return abs(ans);
    }
    return calc(u) + calc(v) - calc(lca) * 2;
}
int main()
{
    int T;
    scanf("%d", &T);
    while (T--) {
        int u, v, q, opt;
        scanf("%d", &n);
        cnt = id = 0;
        for (int i = 1; i <= n; i++)
            G[i].clear();
        for (int i = 1; i <= 4 * n; i++)
            Set[i] = 0;
        for (int i = 1; i < n ;i++) {
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        dfs1(1, 0);
        dfs2(1, 1);
        scanf("%d", &q);
        while (q--) {
            scanf("%d%d", &opt, &u);
            if (opt == 1) {
                int cat = query(1, 1, n, L[u]);
                if (!cat)
                    update(1, 1, n, L[u], R[u], u);
            }
            else {
                scanf("%d", &v);
                printf("%d\n", gao(u, v));
            }
        }
    }
}
发布了302 篇原创文章 · 获赞 98 · 访问量 7万+

猜你喜欢

转载自blog.csdn.net/ccsu_cat/article/details/100065231