【启发式合并】(dsu on tree)讲解

【启发式合并】(dsu on tree)讲解+例题

超级好的讲解(来自cf)


一、启发式合并的作用:

我们可以在O(nlogn) 的时间内,回答下列形式的所有询问:
节点v的子树中有多少节点有某种特性。

例如:
给一棵树,每个节点有一种颜色,询问:节点v的子树中,有多少节点的颜色是color c?

二、解决方法

2.1 准备工作:计算每个节点的子树大小
int sz[maxn];
void getsz(int v, int p)
{
    sz[v] = 1;  // every vertex has itself in its subtree
    for(auto u : g[v])
        if(u != p)
        {
            getsz(u, v);
            sz[v] += sz[u]; // add size of child u to its parent(v)
        }
}

2.2 普通解法(O(N ^ 2)
int cnt[maxn];
void add(int v, int p, int x)
{
    cnt[ col[v] ] += x;
    for(auto u: g[v])
        if(u != p)
            add(u, v, x)
}
        
void dfs(int v, int p)
{
    add(v, p, 1);
    //now cnt[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily.
    add(v, p, -1);
    for(auto u : g[v])
        if(u != p)
            dfs(u, v);
}

个人理解:要求得每一棵子树的值,那么就需要对每一棵子树进行遍历,由于只有一个cnt记录空间,因此每次遍历完后,需要将cnt数组清空,然后才能进行下一个兄弟节点的遍历,否则会将从cnt里的数据搞混淆,得出错误的结果。
这样一个节点会被重复遍历多次,复杂度很高。


2.3 easy to code but O(n log^2 n)
void dfs(int v, int p)
{
    int mx = -1, bigChild = -1;
    for(auto u : g[v])
        if(u != p)
        {
            dfs(u, v);
            if(sz[u] > mx)
                mx = sz[u], bigChild = u;
        }
    if(bigChild != -1)
        cnt[v] = cnt[bigChild];
    else
        cnt[v] = new map<int, int> ();
    (*cnt[v])[ col[v] ] ++;
    for(auto u : g[v])
        if(u != p && u != bigChild)
        {
            for(auto x : *cnt[u])
                (*cnt[v])[x.first] += x.second;
        }
    //now (*cnt[v])[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily.

}
2.4 easy to code and O(nlogn)
vector<int> *vec[maxn];
int cnt[maxn];
void dfs(int v, int p, bool keep)
{
    int mx = -1, bigChild = -1;
    for(auto u : g[v])
        if(u != p && sz[u] > mx)
            mx = sz[u], bigChild = u;
    for(auto u : g[v])
        if(u != p && u != bigChild)
            dfs(u, v, 0);
    if(bigChild != -1)
        dfs(bigChild, v, 1), vec[v] = vec[bigChild];
    else
        vec[v] = new vector<int> ();
    vec[v]->push_back(v);
    cnt[ col[v] ]++;
    for(auto u : g[v])
        if(u != p && u != bigChild)
            for(auto x : *vec[u])
            {
                cnt[ col[x] ]++;
                vec[v] -> push_back(x);
            }
    //now (*cnt[v])[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily.
    // note that in this step *vec[v] contains all of the subtree of vertex v.
    if(keep == 0)
        for(auto u : *vec[v])
            cnt[ col[u] ]--;
}

2.5 heavy-light decomposition style O(nlogn) 【重点】

int cnt[maxn];
bool big[maxn];
void add(int v, int p, int x)
{
    cnt[ col[v] ] += x;
    for(auto u: g[v])
        if(u != p && !big[u])
            add(u, v, x)
}

void dfs(int v, int p, bool keep)
{
    int mx = -1, bigChild = -1;
    for(auto u : g[v])
        if(u != p && sz[u] > mx)
            mx = sz[u], bigChild = u;
    for(auto u : g[v])
        if(u != p && u != bigChild)
            dfs(u, v, 0);  // run a dfs on small childs and clear them from cnt
    if(bigChild != -1)
        dfs(bigChild, v, 1), big[bigChild] = 1;  // bigChild marked as big and not cleared from cnt
    add(v, p, 1);
    //now cnt[c] is the number of vertices in subtree of vertex v that has color c. You can answer the queries easily.
    if(bigChild != -1)
        big[bigChild] = 0;
    if(keep == 0)
        add(v, p, -1);
}

个人理解

  • (原始解法)要求得每一棵子树的值,那么就需要对每一棵子树进行遍历,由于只有一个cnt记录空间,因此每次遍历完后,需要将cnt数组清空,然后才能进行下一个兄弟节点的遍历,否则会将从cnt里的数据搞混淆,得出错误的结果。
  • 我们发现其实这里面有些重复的步骤可以省略。比如说,在遍历一个节点u时,因为是dfs,所以一定是先将u的所有子节点都dfs一遍后,然后才会返回到u,对u进行求解。而对u进行求解时,要对子节点求和再加上他自己,还是需要遍历它的所有子节点,那么既然上面已经遍历一遍子节点,那么刚刚遍历的子节点信息不删除,直接在这里用,不就节省了一遍重复遍历的过程吗?
  • 但是不能保留下所有的子节点信息。因为我们进行dfs时,是公用cnt这个数组的,如果不将之前的信息清除掉,就会造成混乱。(比如节点1有三个子节点2、3、4,而4有两个子节点5、6。现在对1进行遍历,那么首先要对子节点遍历,也就是进行后序遍历:2、3、5、6、4、1。如果对2进行遍历后,不清空cnt,那么接着对3进行遍历的时候,就会把2和3的信息混合,导致3的信息出错。)
  • 但是我们可以保留最后一个遍历的子节点的信息,然后再把前面清空掉的信息,再遍历一遍,这样相当于少遍历了一遍。(还是上面那个例子,我们遍历完4的信息后,马上就要返回1了,而计算1的时候,还是需要对2、3、4计算一遍,因为1 = 1 + 2 + 3+ 4。那么从4返回1的时候,4的信息就不用清除了,此时cnt中装有4的信息,即cnt[1] = cnt[4]。然后再把2、3和1自己的信息加上,就完成了对1的遍历。这个过程中可以看到,和原始做饭做法相比,少计算了一次4的值)。
  • 既然可以少算一个子节点的值,那么选择哪个节点作为这个最后一个即可以被少算的这个节点,是至关重要的。显然,我们应该选择规模最大的一个子节点(也叫作重儿子),这样就可以少花点时间。
  • 于是可以得出算法:每一次先对轻儿子进行遍历,求出轻儿子的ans,在这个过程中,每一次都要清空cnt。然后计算重儿子的ans,计算完后,不对cnt清空,接着求父节点自己的,再把轻儿子的值再遍历一遍,累加到cnt中。这样就完成了对一整棵树所有节点ans的遍历求解。
  • 和原始做法相比,我们节约了对重儿子的重复计算。

自己修改过的板子

//CodeForces-600E

#include <iostream>
#include <stdio.h>
#include <string.h>
#include <string>
#include <math.h>
#include <algorithm>
#include <vector>

using namespace std;

const int maxn = 1e5 + 100;

int col[maxn], sz[maxn], son[maxn], cnt[maxn];
int n, big, maxx;
long long ans[maxn], sum;
vector<int> G[maxn];

void get_son(int u, int p)
{
    sz[u] = 1;
    int len = G[u].size();
    for(int i = 0; i < len; i++)
    {
        int v = G[u][i];
        if(v != p)
        {
            get_son(v, u);
            sz[u] += sz[v];
            if(sz[v] > sz[son[u]])
                son[u] = v;
        }
    }
}

void add(int u, int p, int x)
{
    cnt[ col[u] ] += x;
    if(cnt[ col[u] ] > maxx)
    {
        sum = col[u];
        maxx = cnt[ col[u] ];
    }
    else if(cnt[ col[u] ] == maxx)
    {
        sum += col[u];
    }

    int len = G[u].size();
    for(int i = 0; i < len; i++)
    {
        int v = G[u][i];
        if(v != p && v != big)
            add(v, u, x);
    }
}

void dfs(int u, int p, bool keep)
{
    int len = G[u].size();
    for(int i = 0; i < len; i++)
    {
        int v = G[u][i];
        if(v != p && v != son[u])
            dfs(v, u, 0);
    }
    if(son[u])
    {
        dfs(son[u], u, 1);
        big = son[u];
    }
    add(u, p, 1);
    big = 0;
    ans[u] = sum;
    if(!keep)
    {
        add(u, p, -1);
        maxx = 0;
        sum = 0;
    }
}


int main()
{
    while(~scanf("%d", &n))
    {
        big = 0;
        maxx = 0;
        sum = 0;
        for(int i = 1; i <= n; i++)
        {
            G[i].clear();
            sz[i] = son[i] = cnt[i] = ans[i] = 0;
        }
        for(int i = 1; i <= n; i++)
            scanf("%d", &col[i]);
        int u, v;
        for(int i = 0; i < n - 1; i++)
        {
            scanf("%d%d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        get_son(1, 0);
        dfs(1, 0, 1);
        for(int i = 1; i <= n; i++)
        {
            if(i != 1)
                printf(" ");
            printf("%lld", ans[i]);
        }
        printf("\n");
    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Floraqiu/article/details/86560162