题目链接:点我啊╭(╯^╰)╮
题目大意:
一棵树,两种操作:
①:在点
放
个蘑菇
②:将起点变为 v$
每次操作后计算起点收集所有蘑菇的代价
收集一个蘑菇的代价为起点到终点最短路径上的第一条边权
解题思路:
计算一个点
的答案,分三部分计算
重儿子的所在子树的所有答案
这些蘑菇都以
到重儿子这条边的边权为代价
所以每次更新点
时,维护
到根节点的重链上的蘑菇数
所有轻儿子所在子树的答案
这部分在第一步时,轻重链交换时暴力统计
以父亲边边权为代价的答案
用总蘑菇数
上两种情况的蘑菇和即可
核心:轻重链剖分的应用
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
#define fi first
#define se second
using namespace std;
typedef long long ll;
using pii = pair <ll,ll>;
const int maxn = 1e6 + 5;
int n, q, dep[maxn], fa[maxn], fv[maxn], size[maxn];
ll sum, t[maxn<<2], lz[maxn<<2], cnt[maxn];
int dfn[maxn], id[maxn], tot, son[maxn], top[maxn];
vector <pii> g[maxn];
pii ans[maxn];
void dfs1(int u, int f, int de) {
dep[u] = de, fa[u] = f, size[u] = 1;
for(auto tmp : g[u]) {
int v = tmp.fi;
int w = tmp.se;
if(v == f) continue;
dfs1(v, u, de+1);
fv[v] = w;
size[u] += size[v];
if(size[son[u]] < size[v]) son[u] = v;
}
}
void dfs2(int u, int tp) {
top[u] = tp, dfn[++tot] = u, id[u] = tot;
if(son[u]) dfs2(son[u], tp);
for(auto tmp : g[u]) {
int v = tmp.fi;
if(v == fa[u]) continue;
if(v == son[u]) continue;
dfs2(v, v);
}
}
void pushdown(int rt) {
if(lz[rt]) {
t[rt<<1] += lz[rt];
t[rt<<1|1] += lz[rt];
lz[rt<<1] += lz[rt];
lz[rt<<1|1] += lz[rt];
lz[rt] = 0;
}
}
void update(ll x, int L, int R, int l, int r, int rt) {
if(l>R || r<L) return;
if(l>=L && r<=R) {
t[rt] += x;
lz[rt] += x;
return;
}
pushdown(rt);
int mid = l + r >> 1;
update(x, L, R, l, mid, rt<<1);
update(x, L, R, mid+1, r, rt<<1|1);
t[rt] = t[rt<<1] + t[rt<<1|1];
}
ll query(int pos, int l, int r, int rt) {
if(pos>r || pos<l) return 0;
if(l == r) return t[rt];
pushdown(rt);
int mid = l + r >> 1; ll ret = 0;
ret += query(pos, l, mid, rt<<1);
ret += query(pos, mid+1, r, rt<<1|1);
return ret;
}
void gao(int u, int x) {
while(u) {
update(x, id[top[u]], id[u], 1, n, 1);
u = top[u];
ans[fa[u]].fi += 1ll * x * fv[u];
ans[fa[u]].se += x;
u = fa[u];
}
}
void solve(int u) {
ll res = 0, num = query(id[son[u]], 1, n, 1);
res += 1ll * num * fv[son[u]];
res += 1ll * (sum - cnt[u] - num - ans[u].se) * fv[u];
res += ans[u].fi;
printf("%lld\n", res);
}
int main() {
scanf("%d", &n);
for(int i=1, u, v, w; i<n; i++) {
scanf("%d%d%d", &u, &v, &w);
g[u].push_back({v, w});
g[v].push_back({u, w});
}
dfs1(1, 0, 1);
dfs2(1, 1);
scanf("%d", &q);
int op, v, x, rt = 1;
while(q--) {
scanf("%d", &op);
if(op == 1) {
scanf("%d%d", &v, &x);
sum += x;
cnt[v] += x;
gao(v, x);
} else scanf("%d", &rt);
solve(rt);
}
}