dsu on tree(树上启发式合并)

简介

对于一颗静态树,O(nlogn)时间内处理子树的统计问题。是一种优雅的暴力。

算法思想

很显然,朴素做法下,对于每颗子树对其进行统计的时间复杂度是平方级别的。考虑对树进行一个重链剖分。虽然都基于重链剖分,但不同于树剖,我们维护的不是树链。
对于每个节点,我们先处理其轻儿子所在子树,轻子树在处理完后消除其影响。然后处理重儿子所在子树,保留其贡献。然后再暴力跑该点的轻子树,统计该点子树的最终答案。如果该点子树是轻子树,则消除该子树的影响,否则保留。用代码描述的话,大概是这个流程:

void dfs(int u,int fa,int hvy)
{
    for(v :G[u])//处理轻子树
    {
        if(v==f||v==son[u])
            continue;
        dfs(v,u,0);
    }
    if(son[u])//处理重子树
        dfs(son[u],u,1);
    calc(u,fa,1);//暴力统计轻子树对该点答案的贡献
    ans[u]=res;
    if(!hvy)
        calc(u,fa,-1);//若点u所在子树是轻子树,则逆着原来统计的操作来消除其影响。
}

以上体现大概思想,但遇到具体题目可能有很多细节需要思考。

复杂度分析

这个可能不能很容易的明白其为何高效,如何达到O(nlogn)。因此我们考虑每个节点对时间复杂度的贡献。如果真的明白上述的算法流程,可以知道我们执行暴力统计的都是对轻边所连的子树,因此每个点被遍历到的次数与它往上到根的轻边数量有关。而任一点到根的路径上,轻边的数量不会超过logn。因此每个点最多被遍历logn次。这样想应该好理解很多。

举例

Lomsat gelral
这是一道比较经典的入门题,有兴趣的可以练手,感受一下算法的思想,再做下一题。在此不给出代码。
下面稍微讲一下D. Arpa’s letter-marked tree and Mehrdad’s Dokhtar-kosh paths
感觉这道题还是挺难的,要考虑不少细节。
题意大概就是每条边有一个字符(a-v),求每颗子树下最长的一条简单路径,其上的字符可重组成回文串。显然就是要至多只有一个字符出现奇数次。
我们把每种字符看作二进制上的一个位,即2的幂。则满足条件的简单路径,其边权异或结果必须为0或2的幂。
因此用到dp和dsu on tree的思想。a[i]表示点i到根的路径异或值,dp[i]表示a[x]=i的点中,深度最大的x的深度。
对于一颗以u为根的子树,它的答案路径(该路径默认包含u,因此可能不是最终答案)可能是1.u到其子树中某点的简单路径;2.u的两颗不同子树中的两点间的路径。前者直接判断来更新答案;对于后者两颗子树间的情况,需要不断更新每个异或值下的最大深度,方便对于跑到的点可以知道此时与它满足条件的另一点的最大深度,从而得知路径长来更新答案。然后若该子树为重子树,则保留dp信息,否则重置。
附上代码

#include<bits/stdc++.h>
#define dd(x) cout<<#x<<" = "<<x<<" "
#define de(x) cout<<#x<<" = "<<x<<"\n"
#define sz(x) int(x.size())
#define All(x) x.begin(),x.end()
#define pb push_back
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef long double ld;
typedef pair<int,int> P;
typedef priority_queue<int> BQ;
typedef priority_queue<int,vector<int>,greater<int> > SQ;
const int maxn=5e5+10,mod=1e9+7,INF=0x3f3f3f3f;
int a[maxn],dp[maxn*10],sz[maxn],d[maxn],son[maxn],ans[maxn];
vector<int> G[maxn];
void dfs1(int u,int fa)
{
    sz[u]=1;
    d[u]=d[fa]+1;
    a[u]^=a[fa];
    for (auto& v:G[u])
    {
        dfs1(v,u);
        sz[u]+=sz[v];
        if (sz[v]>sz[son[u]])
            son[u]=v;
    }
}
int mx;
bool check(int x,int y)
{
    int t=x^y,cnt=0;
    for (int i=0;i<='v'-'a';++i)
        cnt+=(t>>i)&1;
    return cnt<=1;
}
void cal(int rt,int u)
{
    if (check(a[u],a[rt]))
        mx=max(mx,d[u]-d[rt]);
    mx=max(mx,dp[a[u]]+d[u]-2*d[rt]);
    for (int i=0;i<='v'-'a';++i)
        mx=max(mx,dp[a[u]^(1<<i)]+d[u]-2*d[rt]);
    for (auto& v:G[u])
        cal(rt,v);
}
void upd(int u,int ty)
{
    if (ty)
        dp[a[u]]=max(dp[a[u]],d[u]);
    else
        dp[a[u]]=-INF;
    for (auto& v:G[u])
        upd(v,ty);
}
void dfs2(int u,int hvy)
{
    for (auto&v :G[u])
    {
        if (v==son[u])
            continue;
        dfs2(v,0);
    }
    if (son[u])
        dfs2(son[u],1);
    mx=0;
    mx=max(mx,dp[a[u]]-d[u]);
    for (int i=0;i<='v'-'a';++i)
        mx=max(mx,dp[a[u]^(1<<i)]-d[u]);
    for (auto& v:G[u])
    {
        if (v==son[u])
            continue;
        cal(u,v);
        upd(v,1);
    }
    ans[u]=mx;
    if (hvy)
        dp[a[u]]=max(dp[a[u]],d[u]);
    else
    {
        for (auto& v:G[u])
            upd(v,0);
        dp[a[u]]=-INF;
    }
}
void solve(int u)
{
    for (auto& v:G[u])
    {
        solve(v);
        ans[u]=max(ans[u],ans[v]);
    }
}
int main()
{
    int n;
    cin>>n;
    char c[2];
    for (int i=2;i<=n;++i)
    {
        int f;
        scanf("%d%s",&f,c);
        G[f].pb(i);
        a[i]=1<<(c[0]-'a');
    }
    for (int i=1;i<maxn*10;++i)
        dp[i]=-INF;
    dfs1(1,0);
    dfs2(1,1);
    solve(1);
    for (int i=1;i<=n;++i)
        printf("%d ",ans[i]);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/orangee/p/10463899.html