牛客挑战赛21 E 题 未来城市规划 【树链剖分 + 线段树 + 思维】

传送门
题目大意: 给定一颗以1为根有根树, 树上有边权, 每次询问有两种操作,
INC u v w 代表把u到v的路径上所有边权都增加w
ASK P 询问以P为根的子树内任意两点之间的距离和

思路: 增加操作很好做, 直接在树剖上做就可以了, 关键时怎么处理查询, 很明显我们要找每条边被覆盖的次数的关系来解决, 在树上的一条边, 如果考虑所有经过节点对数经过它次数, 假设连接这条边的时节点u, v那么次数就等于siz[v](siz[u]-siz[v]), 其中siz[u] 表示u的子树内的节点数, u时v的父亲, 所以展开有(siz[u]*siz[v] + siz[v]^2) w, 就等于siz[u]*siz[v]*w+ siz[v]^2*w, 所以我们线段树就可以维护5个值, s1代表这个区间表示的节点的siz和, s2是siz的平方和, v1是siz[v]*w的答案, v2是siz[v]^2*w的答案, 然后还有一个lazy标记即可. 注意每次我们增加一个边权时, 推导也可知我们只会改变v1, v2的值, 即s1, s2后面都不会有变化的. 具体细节请看代码.

注意: 我们是将边权转化为了点权做的, 所以部分区间应该找对!!! 因为是减法%, 注意最后是+mod%mod!!! 减法%都要注意这个问题啊!!!

【权值在边上 + 区间加 + 询问区间和 + 树上边被覆盖的次数推导!!!】
AC Code

const int maxn = 5e4 + 5;
int n, m, cnt, head[maxn], tim;
int siz[maxn], top[maxn];
int son[maxn], dep[maxn], fa[maxn];
int dd[maxn], dis[maxn]; // 表示把那条边换到哪个点上.
ll a[maxn];
int tid[maxn], out[maxn], pos[maxn];
struct Tree {
    int tl, tr;
    ll s1, s2, lazy;
    ll v1, v2;
    void fun(ll tmp) {
        lazy = (lazy + tmp) % mod;
        v1 = (v1 + s1*tmp) % mod;
        v2 = (v2 + s2*tmp) % mod;
    }
}tre[maxn<<2];
void pushup(int id) {
    tre[id].v1 = (tre[id<<1].v1 + tre[id<<1|1].v1) % mod;
    tre[id].v2 = (tre[id<<1].v2 + tre[id<<1|1].v2) % mod;
}
void pushdown(int id) {
    if(tre[id].lazy) {
        tre[id<<1].fun(tre[id].lazy);
        tre[id<<1|1].fun(tre[id].lazy);
        tre[id].lazy = 0;
    }
}
void build(int id,int l,int r) {
    tre[id].tl = l; tre[id].tr = r; tre[id].lazy = 0;
    tre[id].s1 = tre[id].s2 = 0;
    if(l == r) {
        tre[id].s1 = siz[pos[l]] % mod;
        tre[id].s2 = tre[id].s1 * tre[id].s1 % mod;
        tre[id].v1 = a[pos[l]] * tre[id].s1 % mod;
        tre[id].v2 = a[pos[l]] * tre[id].s2 % mod;
        return ;
    }
    int mid = (l+r) >> 1;
    build(id<<1, l, mid);
    build(id<<1|1, mid+1, r);
    tre[id].s1 = (tre[id<<1].s1 + tre[id<<1|1].s1) % mod;
    tre[id].s2 = (tre[id<<1].s2 + tre[id<<1|1].s2) % mod;
    pushup(id);
}
void update(int id, int ul, int ur, ll val) {
    int l = tre[id].tl, r = tre[id].tr;
    if(ul <= l && r <= ur) {
        tre[id].fun(val);
        return ;
    }
    pushdown(id);
    int mid = (l+r) >> 1;
    if(ul <= mid) update(id<<1, ul, ur, val);
    if(ur > mid) update(id<<1|1, ul, ur, val);
    pushup(id);
}
ll query_sum(int id, int ql, int qr, int f) {
    int l = tre[id].tl , r = tre[id].tr;
    if(ql <= l && r <= qr) {
        return f ? tre[id].v1 : tre[id].v2;
    }
    pushdown(id);
    int mid = (l+r) >> 1 ;
    if(qr <= mid) return query_sum(id<<1, ql, qr, f);
    else if(ql > mid) return query_sum(id<<1|1, ql, qr, f);
    else return query_sum(id<<1, ql, mid, f) + query_sum(id<<1|1, mid+1, qr, f);
}
struct node {
    int to, next, idx;
}e[maxn<<1];
void add(int u, int v, int id) {
    e[cnt] = node{v, head[u], id};
    head[u] = cnt++;
}
void init() {
    cnt = 0; Fill(head, -1);
    tim = 0; Fill(son, -1);
}
void dfs1(int u, int f, int deep) {
    dep[u] = deep + 1; siz[u] = 1;
    for (int i = head[u] ; ~i ; i = e[i].next) {
        int to = e[i].to;
        if (to == f) continue;
        fa[to] = u;
        dd[e[i].idx] = to;
        dfs1(to, u, deep+1);
        siz[u] += siz[to];
        if (son[u] == -1 || siz[to] > siz[son[u]]) {
            son[u] = to;
        }
    }
}
void dfs2(int u, int tp) {
    top[u] = tp;
    tid[u] = ++tim;
    pos[tim] = u;
    if (son[u] == -1) {
        out[u] = tim;
        return ;
    }
    dfs2(son[u], tp);
    for (int i = head[u] ; ~i ; i = e[i].next) {
        int to = e[i].to;
        if (to != son[u] && to != fa[u]) {
            dfs2(to, to);
        }
    }
    out[u] = tim;
}
void add_val(int x, int y, ll val) {
    for (;top[x] != top[y] ; x = fa[top[x]]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(1, tid[top[x]], tid[x], val);
    }
    if (x == y) return ;
    if (dep[x] > dep[y]) swap(x, y);
    if (tid[son[x]] > tid[y]) swap(x, y);
    update(1, tid[son[x]], tid[y], val);
}
ll get_sum(int u) {
    return 1ll*siz[u]*query_sum(1, tid[u]+1, out[u], 1) \
            - query_sum(1, tid[u]+1, out[u], 0);
}
ll val[maxn];
void solve() {
    scanf("%d%d", &n, &m); init();
    for (int i = 2 ; i <= n ; i ++) {
        int u, w;
        scanf("%d%d", &u, &w);
        add(u, i, i); add(i, u, i);
        val[i] = 1ll*w;
    }
    dfs1(1, -1, 0); dfs2(1, 1);
    for (int i = 2 ; i <= n ; i ++) {
        a[dd[i]] = val[i];
    }
    build(1, 1, n);
    char op[10];
    while(m--) {
        scanf("%s", op);
        int u, v, w;
        if (op[0] == 'I') {
            scanf("%d%d%d", &u, &v, &w);
            if (u == v) continue;
            add_val(u, v, 1ll*w);
        }
        else {
            scanf("%d", &u);
            if (tid[u] == out[u]) puts("0");
            else {
                ll ans = get_sum(u);
                ans = ((ans%mod) + mod)% mod;
                printf("%lld\n", ans);
            }
        }
    }
}

猜你喜欢

转载自blog.csdn.net/Anxdada/article/details/81412029