H - Traffic Network in Numazu HDU - 6393(基环树)

Traffic Network in Numazu (HDU - 6393)

题意:给定一张\(n\)个点\(n\)条边的带权图。要求支持两种操作:

  • \(0\ x\ y :\)修改第\(x\)条边的权值为\(y\)
  • \(1\ x\ y :\)查询\((x,y)\)的最短路。

题解:

\(n\)个点\(n\)条边,就是一颗基环树。我们可以拆掉基环树上的一条边,变为一棵树。那么两个点的最短路就是树上的距离和经过拆掉的边的距离,取最小值。

对于树上的距离,我们可以用线段树维护每个点到根节点的距离来求出。每次修改一条边的权值时,就把这条边以下的子树整体修改。

对于\((x,y)\)经过被拆掉的边\((u,v,len)\)的情况,我们可以在\((x, u)+(y, v)+len\)\((x,v)+(y,u)+len\)\(min\)

最后把上面两种情况取\(min\)就是询问的答案。

代码:

#include <bits/stdc++.h>
#define fopi freopen("in.txt", "r", stdin)
#define fopo freopen("out.txt", "w", stdout)
using namespace std;
typedef long long LL;
typedef pair<int, LL> Pair;
const int inf = 0x3f3f3f3f;
const int maxn = 1e5 + 10;

LL d[maxn];
struct SegTree {
    struct Node {
        int l, r;
        LL sum, add;
    }t[maxn*4];

    void build(int id, int l, int r) {
        t[id].l = l, t[id].r = r;
        t[id].add = 0;
        if (l == r) {
            t[id].sum = d[t[id].l];
            return;
        }
        int mid = (l+r) / 2;
        build(id*2, l, mid);
        build(id*2+1, mid+1, r);
        t[id].sum = t[id*2].sum + t[id*2+1].sum;
    }

    void pushdown(int id) {
        if (t[id].add != 0) {
            t[id*2].add += t[id].add;
            t[id*2+1].add += t[id].add;
            int mid = (t[id].l + t[id].r) / 2;
            t[id*2].sum += t[id].add * (mid-t[id].l+1);
            t[id*2+1].sum += t[id].add * (t[id].r-mid);
            t[id].add = 0;
        }
    }

    void update(int id, int l, int r, LL val) {
        if (l <= t[id].l && r >= t[id].r) {
            t[id].add += val;
            t[id].sum += val * (t[id].r-t[id].l+1);
            return;
        }
        pushdown(id);
        int mid = (t[id].l + t[id].r) / 2;
        if (r <= mid) update(id*2, l, r, val);
        else if (l > mid) update(id*2+1, l, r, val);
        else update(id*2, l, mid, val), update(id*2+1, mid+1, r, val);
        t[id].sum = t[id*2].sum + t[id*2+1].sum;
    }

    LL query(int id, int x) {
        if (t[id].l == x && t[id].r == x) return t[id].sum;
        pushdown(id);
        int mid = (t[id].l + t[id].r) / 2;
        if (x <= mid) query(id*2, x); else query(id*2+1, x);
    }
}ST;

int fa[maxn][22], dep[maxn], dfn[maxn], dfr[maxn];
vector<Pair> V[maxn];
int depth, tot;

void init_lca(int x, int from) {
    dfn[x] = ++tot;
    dep[x] = dep[from] + 1;
    for (auto p : V[x]) {
        int y = p.first, l = p.second;
        if (y == from) continue;
        fa[y][0] = x;
        for (int j = 1; j <= depth; j++)
            fa[y][j] = fa[fa[y][j-1]][j-1];
        init_lca(y, x);
    }
    dfr[x] = tot;
}

void dfs(int x, int from) {
    for (auto p : V[x]) {
        int y = p.first, l = p.second;
        if (y == from) continue;
        d[dfn[y]] = d[dfn[x]] + l;
        dfs(y, x);
    }
}

int lca(int x, int y) {
    if (dep[x] > dep[y]) swap(x, y);
    for (int i = depth; i >= 0; i--)
        if (dep[fa[y][i]] >= dep[x]) y = fa[y][i];
    if (x == y) return x;
    for (int i = depth; i >= 0; i--)
        if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}

LL dist(int x, int y) {
    int L = lca(x, y);
    LL d1 = ST.query(1, dfn[x]), d2 = ST.query(1, dfn[y]), d3 = ST.query(1, dfn[L]);
    return d1 + d2 - 2 * d3;
}

int T, n, m;
LL z[maxn];
int y[maxn];
int main() {
    scanf("%d", &T);
    for (int ca = 1; ca <= T; ca++) {
        scanf("%d%d", &n, &m);
        for (int i = 1; i <= n; i++) V[i].clear();

        for (int i = 1; i <= n-1; i++) {
            int x;
            scanf("%d%d%d", &x, &y[i], &z[i]);
            //其实这里应该按照dep的深度存第i条边的儿子节点。
            //所幸题目中没有逆序边,我也没wa。
            V[x].push_back({y[i], z[i]});
            V[y[i]].push_back({x, z[i]});
        }
        int xn, yn;
        scanf("%d%d%d", &xn, &yn, &z[n]);

        depth = 20, tot = 0;
        init_lca(1, 0);
        dfs(1, 0);
        ST.build(1, 1, n);
        for (int i = 1; i <= m; i++) {
            int op, x, val;
            scanf("%d%d%d", &op, &x, &val);
            if (op == 0) {
                if (x == n) { z[n] = val; continue; }
                LL deta = val - z[x];
                ST.update(1, dfn[y[x]], dfr[y[x]], deta);
                z[x] = val;
            }
            else {
                int fx = dfn[x], fy = dfn[val];
                LL d1 = dist(x, val),
                    d2 = dist(x, xn) + dist(val, yn) + z[n],
                    d3 = dist(x, yn) + dist(val, xn) + z[n];
                printf("%lld\n", min(d1, min(d2, d3)));
            }
        }
    }
}

猜你喜欢

转载自www.cnblogs.com/ruthank/p/11369306.html