「数据结构」 树上启发式合并

dsu on tree 用来解决树上问题。可以在 \(O(nlgon)\) 中完成对静态的子树统计。但是,不支持修改,只能对子树统计,不能链上统计。

我们来看一个问题。有一棵树,每个点有一个权值。求这棵树的每一棵子树的众数权值之和,如果有多个众数那么都要统计。先考虑 \(O(n^2)\) 的暴力,对于每一棵子树,遍历这棵子树的所有点,用一个桶记录每一个数出现的次数,统计一下众数之和,然后清空桶消除影响。下面是代码。

void add(int x, int fa, int val) {
    cnt[col[x]] += val;
    if(cnt[col[x]] > mx) mx = cnt[col[x]], sum = col[x];
    else if(cnt[col[x]] == mx) sum += col[x];
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if (y != fa) add(y, x, val);
    }
}
void dfs(int x, int fa) {
    for(int i = 0; i < G[x].size(); i++) {
        int y = G[x][i];
        if(y != fa) dfs(y, x);
    }
    add(x, fa, 1); ans[x] = sum;
    add(x, fa, -1), sum = 0, mx = 0;
}

可以发现,最后一个搜到的儿子是没有必要消除影响的,消除影响占用时间的儿子是重儿子。那么可以在暴力上加了一个不消除重儿子影响的优化。这是 dsu on tree 的核心思想,dsu on tree 的流程如下,有先后顺序。

  1. 遍历每一个节点
  2. 递归解决所有的轻儿子,同时消除递归产生的影响
  3. 递归重儿子,不消除递归的影响
  4. 暴力统计所有轻儿子对答案的影响
  5. 更新该节点的答案
  6. 暴力删除所有轻儿子对答案的影响

只加了一个不消除重儿子影响的优化。其他都是暴力,好像还是 \(O(n)\) 的,其实不然。因为一个节点到根的路径上重链和轻链个数不会超过 \(logn\) 条,只有 dfs 到轻边时,才会将轻儿子的子树中合并到上一级的重链,那么每一个点最多向上合并 \(logn\) 次,整体复杂度 \(O(nlogn)\)

学会了 dsu on tree 后,我们可以搞定上面的问题了,那是 CodeForeces600E Lomsat gelral。

#include <bits/stdc++.h>
using namespace std;
#define re register
#define F first
#define S second
#define mp make_pair
#define lson (p << 1)
#define rson (p << 1 | 1)
typedef long long ll;
typedef pair<int, int> P;
const int N = 5e5 + 5, M = 5e5 + 5;
const int INF = 0x3f3f3f3f;
inline int read() {
    int X = 0,w = 0; char ch = 0;
    while(!isdigit(ch)) {w |= ch == '-';ch = getchar();}
    while(isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48),ch = getchar();
    return w ? -X : X;
}
inline void write(int x){
    if(x < 0) putchar('-'), x = -x;
    if(x > 9) write(x / 10);
    putchar(x % 10 + '0');
}
int n, val[N], cnt[N], mx;
ll sum, ans[N];
struct edge{
    int to, nxt;
}e[M];
int head[N], tot;
void addedge(int x, int y){
    e[++tot].to = y, e[tot].nxt = head[x], head[x] = tot;
}
int sz[N], son[N];
void dfs1(int x, int fa){
    sz[x] = 1;
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa){
            dfs1(y, x); sz[x] += sz[y];
            if (sz[y] > sz[son[x]]) son[x] = y;
        }
    }
}
bool vis[N];
void add(int x, int fa, int k){
    cnt[val[x]] += k;
    if (k > 0 && cnt[val[x]] > mx) sum = val[x], mx = cnt[val[x]];
    else if (k > 0 && cnt[val[x]] == mx) sum += val[x];
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && !vis[y]) add(y, x, k);
    }
}
void dfs2(int x, int fa, int flg){
    for (int i = head[x]; i; i = e[i].nxt){
        int y = e[i].to;
        if (y != fa && y != son[x]) dfs2(y, x, 0);
    }
    if (son[x]) dfs2(son[x], x, 1), vis[son[x]] = 1;
    add(x, fa, 1); ans[x] = sum;
    if (son[x]) vis[son[x]] = 0; 
    if (!flg) add(x, fa, -1), sum = mx = 0;
}
int main() {
    n = read();
    for (int i = 1; i <= n; i++) val[i] = read();
    for (int i = 1; i < n; i++){
        int x = read(), y = read();
        addedge(x, y); addedge(y, x);
    }
    dfs1(1, 0); dfs2(1, 0, 0);
    for (int i = 1; i <= n; i++) printf("%lld ", ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/lyfoi/p/11621976.html