「Luogu P3178」[HAOI2015]树上操作

有一棵点数为 \(N\) 的树,以点 \(1\) 为根,且树点有边权。然后有 \(M\) 个操作,分为三种:

  • 操作 1 :把某个节点 \(x\) 的点权增加 \(a\)
  • 操作 2 :把某个节点 \(x\) 为根的子树中所有点的点权都增加 \(a\)
  • 操作 3 :询问某个节点 \(x\) 到根的路径中所有点的点权和。

Luogu

分析

我们把树上问题利用 \(dfs\) 序转化成序列问题然后直接上线段树解决即可。

考虑将线段树的每个叶子结点设为在原树上的点到根的点权和。对于单点修改,当前结点增加了 \(a\) ,那么它和它的子树内的结点到根的点权和都会增加 \(a\) 。而对于整个子树的修改,如果增加的值为 \(a\) ,那么 \(x\) 的子树内的结点的答案的增加量为它到 \(x\) 路径上的结点数与 \(a\) 的乘积,设 \(dep_i\)\(i\) 的深度,那么上面的答案即为 \(a\times(dep_u-dep_x+1)\)\(u\)\(x\) 子树内的结点,为了方便,我们将上式分成 \(a\times(1-dep_x)\)\(a\times dep_u\) 两部分,把 \(a\times (1-dep_x)\) 在线段树内打上标记,在询问时再加上 \(a\times dep_u\)

代码

#include <bits/stdc++.h>

#define N 100003
#define ls o<<1
#define rs o<<1|1
#define ll long long

using namespace std;

int gi() {
    int x = 0, f = 1; char c = getchar();
    for ( ; !isdigit(c); c = getchar()) if (c == '-') f = -1;
    for ( ; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
    return x * f;
}

int n, m, tot;
int nxt[N], hd[N], cnt;
int val[N], dep[N], dfn[N], st[N], ed[N];

struct SegmentTree {
    ll add[N << 2], mul[N << 2];
    void pushdown(int o) {
        add[ls] += add[o], add[rs] += add[o];
        mul[ls] += mul[o], mul[rs] += mul[o];
        add[o] = mul[o] = 0;
    }
    void modify(int o, int l, int r, int L, int R, ll x, ll y) {
        if (l > R || r < L) return;
        if (l >= L && r <= R) {
            add[o] += x, mul[o] += y;
            return;
        }
        pushdown(o);
        int mid = l + r >> 1;
        modify(ls, l, mid, L, R, x, y), modify(rs, mid + 1, r, L, R, x, y);
    }
    ll query(int o, int l, int r, int p) {
        if (l == r) return add[o] + 1ll * mul[o] * dep[dfn[p]];
        pushdown(o);
        int mid = l + r >> 1;
        if (p <= mid) return query(ls, l, mid, p);
        else return query(rs, mid + 1, r, p);
    }
} tr;

void insert(int u, int v) { nxt[u] = hd[v], hd[v] = u; }

void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1, dfn[++tot] = u, st[u] = tot;
    for (int v = hd[u]; v; v = nxt[v]) dfs(v, u);
    ed[u] = tot;
}

int main() {
    int opt, x, a, u, v;
    n = gi(), m = gi();
    for (int i = 1; i <= n; ++i) val[i] = gi();
    for (int i = 1; i < n; ++i) {
        u = gi(), v = gi();
        insert(u, v);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; ++i) tr.modify(1, 1, n, st[i], ed[i], val[i], 0);
    for (int i = 1; i <= m; ++i) {
        opt = gi(), x = gi();
        if (opt == 1) tr.modify(1, 1, n, st[x], ed[x], gi(), 0);
        else if (opt == 2) {
            a = gi();
            tr.modify(1, 1, n, st[x], ed[x], 1ll * a * (1 - dep[x]), a);
        }
        else printf("%lld\n", tr.query(1, 1, n, st[x]));
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hlw1/p/12286038.html