// P3384 [Template] Heavy-light decomposition (tree chain partitioning)

 

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

// Array capacity for 1-based node indices; every per-node array below is sized
// by it, and the edge / segment-tree arrays are sized as multiples of it.
const int MAXN = 100005;

// n: node count, m: operation count, rt: root node, mod: modulus for all sums,
// cnt: edges inserted so far, tot: DFS position counter advanced by dfs2.
// NOTE: `m`, `rt` and `val` are shadowed by locals/parameters of the same
// name inside the segment-tree routines.
int n, m, rt, mod, cnt, tot;
int val[MAXN];            // initial value of each node (indexed by node)
int dep[MAXN];            // depth of each node; the root has depth 1
int id[MAXN], ed[MAXN];   // first / last DFS position of each node's subtree
int rk[MAXN];             // inverse of id: the node stored at a DFS position
int sz[MAXN];             // subtree size
int fa[MAXN];             // parent (0 for the root)
int son[MAXN];            // heavy child (0 if the node has no children)
int top[MAXN];            // head of the heavy chain containing the node
ll sum[MAXN * 4];         // segment-tree node sums, kept reduced mod `mod`
ll lz[MAXN * 4];          // pending (lazy) additions per segment-tree node

// Adjacency list in forward-star form: each undirected edge is inserted twice
// (once per direction), hence 2 * MAXN slots; edge index 0 means "no edge".
struct node {
    int to, nex;          // target node, index of the next edge from the same source
}E[MAXN * 2];
int head[MAXN];           // index of the first edge leaving each node (0 = none)

// First heavy-light pass. For node x (reached from parent `pre` at depth `d`)
// record parent, depth and subtree size, and choose son[x] as the child with
// the strictly largest subtree (the first one met in adjacency order wins ties).
void dfs1(int x, int pre, int d) {
    fa[x] = pre;
    dep[x] = d;
    sz[x] = 1;
    for (int e = head[x]; e != 0; e = E[e].nex) {
        const int child = E[e].to;
        if (child != pre) {
            dfs1(child, x, d + 1);
            sz[x] += sz[child];
            if (sz[child] > sz[son[x]]) son[x] = child;
        }
    }
}

// Second heavy-light pass: hand out DFS positions so each heavy chain occupies
// a contiguous range. x receives position id[x] (rk is the inverse map) and
// chain head t. ed[x] is the last position used inside x's subtree, so that
// subtree is exactly the range [id[x], ed[x]].
void dfs2(int x, int t) {
    top[x] = t;
    rk[id[x] = ++tot] = x;
    if (son[x]) {
        dfs2(son[x], t);  // heavy child extends the current chain
        for (int e = head[x]; e; e = E[e].nex) {
            const int v = E[e].to;
            // every remaining child is light and heads a chain of its own
            if (v != son[x] && v != fa[x]) dfs2(v, v);
        }
    }
    ed[x] = tot;
}

// Recompute segment-tree node `rt` as the sum (mod `mod`) of its two children.
void pushup(int rt) {
    const int lc = rt * 2;
    const int rc = rt * 2 + 1;
    sum[rt] = (sum[lc] + sum[rc]) % mod;
}

void pushdown(int l, int r, int rt) {
    if(lz[rt]) {
        int m = l + r >> 1;
        sum[rt << 1] = (sum[rt << 1] + 1LL * (m - l + 1) * lz[rt] % mod) % mod;
        sum[rt << 1 | 1] = (sum[rt << 1 | 1] + 1LL * (r - m) * lz[rt] % mod) % mod;
        lz[rt << 1] = (lz[rt << 1] + lz[rt]) % mod;
        lz[rt << 1 | 1] = (lz[rt << 1 | 1] + lz[rt]) % mod;
        lz[rt] = 0;
    }
}

// Build segment-tree node `rt` over DFS positions [l, r]. Each leaf holds the
// initial value of the node placed at that position (rk maps position -> node).
void build(int l, int r, int rt) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(l, mid, rt << 1);
        build(mid + 1, r, rt << 1 | 1);
        pushup(rt);
    } else {
        sum[rt] = 1LL * val[rk[l]] % mod;
    }
}

void update(int ql, int qr, ll val, int l, int r, int rt) {
    if(ql <= l && qr >= r) {
        sum[rt] += 1LL * (r - l + 1) * val % mod;
        sum[rt] %= mod;
        lz[rt] += val;
        lz[rt] %= mod;
        return;
    }

    pushdown(l, r, rt);
    int m = l + r >> 1;
    if(ql <= m) update(ql, qr, val, l, m, rt << 1);
    if(qr > m) update(ql, qr, val, m + 1, r, rt << 1 | 1);
    pushup(rt);
}

// Return the sum (mod `mod`) of positions [ql, qr], searching from node `rt`,
// which covers [l, r].
ll query(int ql, int qr, int l, int r, int rt) {
    if (ql <= l && r <= qr) return sum[rt];

    pushdown(l, r, rt);
    const int mid = (l + r) >> 1;
    const ll lsum = ql <= mid ? query(ql, qr, l, mid, rt << 1) : 0;
    const ll rsum = qr > mid ? query(ql, qr, mid + 1, r, rt << 1 | 1) : 0;
    return (lsum + rsum) % mod;
}

// Return the sum (mod `mod`) of node values on the tree path between x and y.
// While the endpoints sit on different heavy chains, lift the endpoint whose
// chain head is deeper (x on ties), adding that chain segment. Once both share
// a chain, add the contiguous range between them.
ll cal_sum(int x, int y) {
    ll res = 0;
    while (top[x] != top[y]) {
        int &u = dep[top[x]] >= dep[top[y]] ? x : y;  // endpoint to lift
        res = (res + query(id[top[u]], id[u], 1, n, 1)) % mod;
        u = fa[top[u]];
    }
    const int lo = std::min(id[x], id[y]);
    const int hi = std::max(id[x], id[y]);
    res = (res + query(lo, hi, 1, n, 1)) % mod;
    return res;
}

void update_lian(int x, int y, ll val) {
    int fx = top[x];
    int fy = top[y];
    while(fx != fy) {
        if(dep[fx] >= dep[fy]) {
            update(id[fx], id[x], val, 1, n, 1);
            x = fa[fx]; fx = top[x];
        } else {
            update(id[fy], id[y], val, 1, n, 1);
            y = fa[fy]; fy = top[y];
        }
    }
    if(id[x] <= id[y]) update(id[x], id[y], val, 1, n, 1);
    else update(id[y], id[x], val, 1, n, 1);
}

// Read the tree and the operations, then answer each one:
//   1 a b c : add c to every node on the path a-b
//   2 a b   : print the path sum a-b (mod `mod`)
//   3 a b   : add b to every node in the subtree of a
//   any other opt, read as `a` : print the subtree sum of a (mod `mod`)
int main() {
    // Reduce an input value into [0, mod). C++ `%` keeps the sign of the
    // dividend, so a bare `x % mod` on a negative input would leave negative
    // values in the segment tree and print negative answers.
    auto normalize = [](int v) { return (1LL * v % mod + mod) % mod; };

    scanf("%d%d%d%d", &n, &m, &rt, &mod);
    cnt = 0;
    tot = 0;
    for(int i = 1; i <= n; i++) {
        scanf("%d", &val[i]);
        val[i] = static_cast<int>(normalize(val[i]));  // result < mod, fits in int
    }
    for(int i = 1; i < n; i++) {
        int x, y;
        scanf("%d%d", &x, &y);
        // store the undirected edge in both directions
        E[++cnt].to = y; E[cnt].nex = head[x]; head[x] = cnt;
        E[++cnt].to = x; E[cnt].nex = head[y]; head[y] = cnt;
    }
    dfs1(rt, 0, 1);
    dfs2(rt, rt);
    build(1, n, 1);
    while(m--) {
        int opt;
        scanf("%d", &opt);

        int a, b, c;
        if(opt == 1) {
            scanf("%d%d%d", &a, &b, &c);
            update_lian(a, b, normalize(c));
        } else if(opt == 2) {
            scanf("%d%d", &a, &b);
            printf("%lld\n", cal_sum(a, b));
        } else if(opt == 3) {
            scanf("%d%d", &a, &b);
            // a's subtree occupies the contiguous DFS range [id[a], ed[a]]
            update(id[a], ed[a], normalize(b), 1, n, 1);
        } else {
            scanf("%d", &a);
            printf("%lld\n", query(id[a], ed[a], 1, n, 1));
        }
    }
    return 0;
}
// Origin: www.cnblogs.com/lwqq3/p/11141452.html