[题解] [CF916E] Jamie and Tree

题解

题面

如果没有换根操作就直接上树剖加线段树即可

考虑换根操作如何转化

记当前的根节点为\(root\)

子树查询和子树修改类似, 在此只讨论子树查询, 假设当前要修改的是\(u\)子树

  • \(u = rt\), 直接修改整棵树即可

  • \(rt\)\(u\)的祖先或\(rt\)\(u\)在原先为\(1\)的两棵不同子树中, 修改\(u\)子树即可
  • \(u\)\(rt\)的祖先, 先修改整棵树, 然后找出\(rt\)的儿子中子树包含\(u\)的那一个儿子, 将以这个儿子为根的子树中撤销修改操作即可

现在问题就是求两点\(u, v\)在换根后的LCA了

这个点就是以\(1\)为根时\(u, root\)的LCA, \(v, root\)的LCA, \(u, v\)的LCA中深度最大的那个点

找出后按上面修改即可

查询同理

Code

#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <vector>
#define itn int
#define reaD read
#define N 300005
using namespace std;
 
int n, m, w[N], cnt, head[N], f[N][21], sz[N], dep[N], son[N], top[N], pre[N], dfn[N], rt = 1; 
struct edge { int to, next; } e[N << 1];
struct Tree { long long sum, tag; } t[N << 2]; 
 
inline int read()
{
    int x = 0, w = 1; char c = getchar();
    while(c < '0' || c > '9') { if (c == '-') w = -1; c = getchar(); }
    while(c >= '0' && c <= '9') { x = x * 10 + c - '0'; c = getchar(); }
    return x * w;
}
 
inline void adde(int u, int v) { e[++cnt] = (edge) { v, head[u] }; head[u] = cnt; }
 
void dfs1(int u, int fa)
{
    f[u][0] = fa;
    dep[u] = dep[fa] + 1; sz[u] = 1; 
    for(int i = 1; i <= 20; i++)
        f[u][i] = f[f[u][i - 1]][i - 1];
    for(int i = head[u]; i; i = e[i].next)
    {
        int v = e[i].to; if(v == fa) continue;
        dfs1(v, u);
        sz[u] += sz[v]; if(sz[son[u]] < sz[v]) son[u] = v; 
    }
}
 
void dfs2(int x, int y)
{
    top[pre[dfn[x] = ++cnt] = x] = y;
    if(!son[x]) return; dfs2(son[x], y);
    for(int i = head[x]; i; i = e[i].next) if(e[i].to != son[x] && e[i].to != f[x][0]) dfs2(e[i].to, e[i].to); 
}
 
void build(int p, int l, int r)
{
    if(l == r) return (void) (t[p].sum = w[pre[l]], t[p].tag = 0);
    int mid = (l + r) >> 1;
    build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
}
 
void pushdown(int p, int l, int r)
{
    if(t[p].tag)
    {
        int mid = (l + r) >> 1; 
        t[p << 1].tag += t[p].tag; t[p << 1].sum += 1ll * t[p].tag * (mid - l + 1);
        t[p << 1 | 1].tag += t[p].tag; t[p << 1 | 1].sum += 1ll * t[p].tag * (r - mid);
        t[p].tag = 0; 
    }
}
 
void modify(int p, int l, int r, int ql, int qr, int k)
{
    if(ql <= l && r <= qr) return (void) (t[p].sum += 1ll * k * (r - l + 1), t[p].tag += k); 
    pushdown(p, l, r);
    int mid = (l + r) >> 1;
    if(ql <= mid) modify(p << 1, l, mid, ql, qr, k);
    if(mid < qr) modify(p << 1 | 1, mid + 1, r, ql, qr, k);
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
}
 
int LCA(int x, int y)
{
    while(top[x] != top[y])
    {
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = f[top[x]][0]; 
    }
    return dep[x] < dep[y] ? x : y; 
}
 
int finds(int u, int v)
{
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 20; i >= 0; i--)
        if(dep[f[u][i]] > dep[v]) u = f[u][i];
    return u; 
}
 
long long query(int p, int l, int r, int ql, int qr)
{
    if(ql <= l && r <= qr) return t[p].sum; 
    pushdown(p, l, r); 
    int mid = (l + r) >> 1; long long ans = 0; 
    if(ql <= mid) ans += query(p << 1, l, mid, ql, qr); 
    if(mid < qr) ans += query(p << 1 | 1, mid + 1, r, ql, qr);
    t[p].sum = t[p << 1].sum + t[p << 1 | 1].sum; 
    return ans; 
}
 
int main()
{
/*  freopen("A.in", "r", stdin);
    freopen("A.out", "w", stdout); 
*/  n = read(); m = read();
    for(int i = 1; i <= n; i++) w[i] = read();
    for(int i = 1; i < n; i++)
    {
        int u = read(), v = read();
        adde(u, v); adde(v, u); 
    }
    cnt = 0; dfs1(1, 0); dfs2(1, 1);
    build(1, 1, n); 
    for(int i = 1; i <= m; i++)
    {
        int opt = read(); 
        if(opt == 1) rt = read();
        if(opt == 2)
        {
            int u = read(), v = read(), x = read(), lca = LCA(u, v); 
            if(LCA(lca, rt) == lca)
            {
                int lcau = LCA(rt, u), lcav = LCA(rt, v); 
                lcau = dep[lcau] < dep[lcav] ? lcav : lcau;
                modify(1, 1, n, 1, n, x); 
                if(rt != lcau)
                {
                    int s = finds(rt, lcau);
                    modify(1, 1, n, dfn[s], dfn[s] + sz[s] - 1, -x); 
                }
            }
            else modify(1, 1, n, dfn[lca], dfn[lca] + sz[lca] - 1, x); 
        }
        if(opt == 3)
        {
            int u = read(); 
            if(u == rt) printf("%I64d\n", query(1, 1, n, 1, n)); 
            else
            {
                int lca = LCA(u, rt); 
                if(lca == u)
                {
                    int s = finds(rt, u); 
                    printf("%I64d\n", query(1, 1, n, 1, n) - query(1, 1, n, dfn[s], dfn[s] + sz[s] - 1)); 
                }
                else printf("%I64d\n", query(1, 1, n, dfn[u], dfn[u] + sz[u] - 1)); 
            }
        }
    }
    return 0;
} 

猜你喜欢

转载自www.cnblogs.com/ztlztl/p/11184646.html