## 题面

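换根后，对"以 \(u\) 为根的子树"进行修改或查询时分三种情况（下述祖先关系均指以 \(1\) 为根的原树）：

• \(u = rt\)：直接修改整棵树即可。

• \(rt\) 是 \(u\) 的祖先，或 \(rt\) 与 \(u\) 位于（以 \(1\) 为根时）两棵不同的子树中（即 \(rt\) 不在 \(u\) 原来的子树内）：修改 \(u\) 原来的子树即可。
• \(u\) 是 \(rt\) 的祖先：先修改整棵树，然后找出 \(u\) 的儿子中子树包含 \(rt\) 的那一个儿子 \(s\)，将以 \(s\) 为根的子树中的修改撤销即可。

对于操作 2，以 \(rt\) 为根时 \(u, v\) 的 LCA 是 \(\operatorname{lca}(u,v)\)、\(\operatorname{lca}(u,rt)\)、\(\operatorname{lca}(v,rt)\) 中深度最大的那个。求出它之后，按上面三种情况修改其子树；操作 3 的查询同理。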

### Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#define itn int
#define N 300005
using namespace std;

// ---- Global state (zero-initialised; vertices are 1-indexed, 0 is a sentinel) ----
int n, m;            // number of vertices, number of queries
int w[N];            // initial value of each vertex
int cnt;             // edge counter while building the graph; reused as the DFS-order counter
int head[N];         // head[u]: index of u's first edge in e[] (0 = no edges)
int f[N][21];        // f[u][k]: 2^k-th ancestor of u under root 1 (0 once past the root)
int sz[N];           // subtree size under root 1
int dep[N];          // depth under root 1 (dep[1] == 1, dep[0] == 0)
int son[N];          // heavy child (0 for a leaf)
int top[N];          // head vertex of the heavy chain containing u
int pre[N];          // pre[i]: vertex placed at DFS position i
int dfn[N];          // dfn[u]: DFS position of vertex u
int rt = 1;          // current root chosen by the queries

// Adjacency-list edge: target vertex and index of the next edge from the same source.
struct edge
{
    int to, next;
};
edge e[N << 1];      // room for 2 * N directed edges

// Lazy segment-tree node: range sum and pending per-element add.
struct Tree
{
    long long sum, tag;
};
Tree t[N << 2];

// Fast reader for one decimal integer from stdin.
// Skips every non-digit character before the number; a '-' among the skipped
// characters makes the result negative. Returns 0 if input ends before any digit
// (instead of looping forever on EOF).
inline int read()
{
    int x = 0, sign = 1;
    int c = getchar();  // int, not char, so EOF is distinguishable from a real byte
    while((c < '0' || c > '9') && c != EOF) { if(c == '-') sign = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
    return x * sign;
}

// Prepends the directed edge u -> v to u's adjacency list (head[] / e[].next chain).
inline void adde(int u, int v)
{
    ++cnt;
    e[cnt].to = v;
    e[cnt].next = head[u];
    head[u] = cnt;
}

// First heavy-light pass over the tree rooted at vertex 1.
// For every vertex in u's subtree fills: binary-lifting table f[][],
// depth dep[], subtree size sz[], and heavy child son[] (largest child subtree).
void dfs1(int u, int fa)
{
    dep[u] = dep[fa] + 1;
    sz[u] = 1;
    f[u][0] = fa;
    for(int k = 1; k <= 20; ++k)
    {
        const int half = f[u][k - 1];   // 2^(k-1)-th ancestor
        f[u][k] = f[half][k - 1];
    }
    for(int id = head[u]; id != 0; id = e[id].next)
    {
        const int child = e[id].to;
        if(child == fa) continue;
        dfs1(child, u);
        sz[u] += sz[child];
        if(sz[child] > sz[son[u]]) son[u] = child;
    }
}

// Second heavy-light pass: assigns DFS positions dfn[] (heavy child first, so
// every heavy chain and every subtree occupies a contiguous range), the inverse
// map pre[], and chain heads top[]. y is the head of the chain x belongs to.
// Expects cnt to be 0 before the top-level call.
void dfs2(int x, int y)
{
    const int order = ++cnt;
    dfn[x] = order;
    pre[order] = x;
    top[x] = y;
    const int heavy = son[x];
    if(heavy == 0) return;       // leaf
    dfs2(heavy, y);              // heavy child continues the current chain
    for(int id = head[x]; id; id = e[id].next)
    {
        const int child = e[id].to;
        if(child == heavy || child == f[x][0]) continue;
        dfs2(child, child);      // each light child starts a new chain
    }
}

// Builds segment-tree node p over DFS positions [l, r]. The leaf at position l
// holds the initial value of vertex pre[l]; internal nodes store child sums.
void build(int p, int l, int r)
{
    if(l != r)
    {
        const int mid = (l + r) >> 1;
        const int lc = p << 1, rc = lc | 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        t[p].sum = t[lc].sum + t[rc].sum;
        return;
    }
    t[p].sum = w[pre[l]];
    t[p].tag = 0;
}

// Moves node p's pending per-element add (p covers [l, r]) into its two
// children: each child's tag grows by it and its sum by add * (child length).
// p's own sum is left untouched; its tag is cleared.
void pushdown(int p, int l, int r)
{
    const long long add = t[p].tag;
    if(!add) return;
    const int mid = (l + r) >> 1;
    const int lc = p << 1, rc = p << 1 | 1;
    t[lc].tag += add;
    t[lc].sum += add * (mid - l + 1);
    t[rc].tag += add;
    t[rc].sum += add * (r - mid);
    t[p].tag = 0;
}

// Adds k to every DFS position in [ql, qr] that lies in node p's range [l, r].
void modify(int p, int l, int r, int ql, int qr, int k)
{
    if(ql <= l && r <= qr)
    {
        // Fully covered: update the sum now and defer the children via the tag.
        t[p].sum += static_cast<long long>(k) * (r - l + 1);
        t[p].tag += k;
        return;
    }
    pushdown(p, l, r);
    const int mid = (l + r) >> 1;
    const int lc = p << 1, rc = p << 1 | 1;
    if(ql <= mid) modify(lc, l, mid, ql, qr, k);
    if(qr > mid) modify(rc, mid + 1, r, ql, qr, k);
    t[p].sum = t[lc].sum + t[rc].sum;
}

// Lowest common ancestor of x and y in the tree rooted at vertex 1,
// found by repeatedly jumping past the heavy chain whose head is deeper.
int LCA(int x, int y)
{
    for(; top[x] != top[y]; )
    {
        if(dep[top[y]] > dep[top[x]]) std::swap(x, y);
        x = f[top[x]][0];   // leave x's chain via its head's parent
    }
    // Same chain now: the shallower vertex is the LCA.
    if(dep[y] <= dep[x]) return y;
    return x;
}

// Of the two vertices, lifts the deeper one until it sits exactly one level
// below the shallower one (using the binary-lifting table f[][]), and returns it.
// Assumes the shallower vertex is a proper ancestor of the deeper one under
// root 1, so the result is the child of the shallower vertex toward the deeper.
int finds(int u, int v)
{
    int low = u, high = v;
    if(dep[low] < dep[high]) std::swap(low, high);
    const int target = dep[high];
    for(int k = 20; k >= 0; --k)
    {
        const int up = f[low][k];
        if(dep[up] > target) low = up;   // jump only while staying strictly below
    }
    return low;
}

// Returns the sum of the values at DFS positions [ql, qr] that lie in node p's
// range [l, r].
long long query(int p, int l, int r, int ql, int qr)
{
    if(ql <= l && r <= qr) return t[p].sum;
    // pushdown only moves the tag into the children; t[p].sum already includes
    // the pending add (modify folds it in), so p needs no re-aggregation here.
    pushdown(p, l, r);
    const int mid = (l + r) >> 1;
    long long ans = 0;
    if(ql <= mid) ans += query(p << 1, l, mid, ql, qr);
    if(mid < qr) ans += query(p << 1 | 1, mid + 1, r, ql, qr);
    return ans;
}

int main()
{
/*  freopen("A.in", "r", stdin);
freopen("A.out", "w", stdout);
for(int i = 1; i <= n; i++) w[i] = read();
for(int i = 1; i < n; i++)
{
}
cnt = 0; dfs1(1, 0); dfs2(1, 1);
build(1, 1, n);
for(int i = 1; i <= m; i++)
{
if(opt == 1) rt = read();
if(opt == 2)
{
if(LCA(lca, rt) == lca)
{
int lcau = LCA(rt, u), lcav = LCA(rt, v);
lcau = dep[lcau] < dep[lcav] ? lcav : lcau;
modify(1, 1, n, 1, n, x);
if(rt != lcau)
{
int s = finds(rt, lcau);
modify(1, 1, n, dfn[s], dfn[s] + sz[s] - 1, -x);
}
}
else modify(1, 1, n, dfn[lca], dfn[lca] + sz[lca] - 1, x);
}
if(opt == 3)
{
if(u == rt) printf("%I64d\n", query(1, 1, n, 1, n));
else
{
int lca = LCA(u, rt);
if(lca == u)
{
int s = finds(rt, u);
printf("%I64d\n", query(1, 1, n, 1, n) - query(1, 1, n, dfn[s], dfn[s] + sz[s] - 1));
}
else printf("%I64d\n", query(1, 1, n, dfn[u], dfn[u] + sz[u] - 1));
}
}
}
return 0;
}

