树上启发式合并 dsu on tree

dsu on tree 用来解决树上问题。可以在 O ( n log n ) O(n \log n) 中完成对静态的子树统计。但是,不支持修改,只能对子树统计,不能链上统计。

我们来看一个问题。有一棵树,每个点有一个权值。求这棵树的每一棵子树的众数权值之和,如果有多个众数那么都要统计。先考虑 O ( n 2 ) O(n^2) 的暴力,对于每一棵子树,遍历这棵子树的所有点,用一个桶记录每一个数出现的次数,统计一下众数之和,然后清空桶消除影响。下面是代码。

void add(int x, int fa, int val) {
    cnt[col[x]] += val;
    if(cnt[col[x]] > mx) mx = cnt[col[x]], sum = col[x];
    else if(cnt[col[x]] == mx) sum += col[x];
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if (y != fa) add(y, x, val);
    }
}
void dfs(int x, int fa) {
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if(y != fa) dfs(y, x);
    }
    add(x, fa, 1); ans[x] = sum;
    add(x, fa, -1), sum = 0, mx = 0;
}

可以发现,最后一个搜到的儿子是没有必要消除影响的,因为消除影响占用时间最多的儿子是重儿子,所以可以在暴力上加了一个不消除重儿子影响的优化。这是 dsu on tree 的核心思想,dsu on tree 的流程如下,有先后顺序。

  1. 遍历每一个节点
  2. 递归解决所有的轻儿子,同时消除递归产生的影响
  3. 递归重儿子,不消除递归的影响
  4. 暴力统计所有轻儿子对答案的影响
  5. 更新该节点的答案
  6. 暴力删除所有轻儿子对答案的影响

只加了一个不消除重儿子影响的优化。其他都是暴力,好像还是 O ( n ) O(n) 的,其实不然。因为一个节点到根的路径上重链和轻链个数不会超过 log n \log n 条,只有 dfs 到轻边时,才会将轻儿子的子树中合并到上一级的重链,那么每一个点最多向上合并 log n \log n 次,整体复杂度 O ( n log n ) O(n \log n)

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define mp make_pair
#define lson (p << 1)
#define rson (p << 1 | 1)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5, M = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
inline void write(int x){
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
int n, val[N], cnt[N], mx;
ll sum, ans[N];
struct edge{
    int to, nxt;
}e[M];
int head[N], tot;
void addedge(int x, int y){
    e[++tot].to = y, e[tot].nxt = head[x], head[x] = tot;
}
int sz[N], son[N];
void dfs1(int x, int fa){
    sz[x] = 1;
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa){
            dfs1(y, x); sz[x] += sz[y];
            if (sz[y] > sz[son[x]]) son[x] = y;
        }
    }
}
bool vis[N];
void add(int x, int fa, int k){
    cnt[val[x]] += k;
    if (k > 0 && cnt[val[x]] > mx) sum = val[x], mx = cnt[val[x]];
    else if (k > 0 && cnt[val[x]] == mx) sum += val[x];
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && !vis[y]) add(y, x, k);
    }
}
void dfs2(int x, int fa, int flg){
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && y != son[x]) dfs2(y, x, 0);
    }
    if (son[x]) dfs2(son[x], x, 1), vis[son[x]] = 1;
    add(x, fa, 1); ans[x] = sum;
    if (son[x]) vis[son[x]] = 0; 
    if (!flg) add(x, fa, -1), sum = mx = 0;
}
int main() {
    n = read();
    for (int i = 1; i <= n; i++) val[i] = read();
    for (int i = 1; i < n; i++){
        int x = read(), y = read();
        addedge(x, y); addedge(y, x);
    }
    dfs1(1, 0); dfs2(1, 0, 0);
    for (int i = 1; i <= n; i++) printf("%lld ", ans[i]);
    return 0;
}
发布了28 篇原创文章 · 获赞 38 · 访问量 484

猜你喜欢

转载自blog.csdn.net/qq_39984146/article/details/104226020