关于长链剖分

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DT_Kang/article/details/82495751

看这样一个题(dsu on the tree):
给你一棵树,每个节点有一种颜色,问你每个子树x的颜色数最多的那种颜色,如果颜色数相同,那么种类数相加。
考虑最暴力的暴力,对于每个点遍历它的子树,统计答案,然后再撤销。但是这样太傻了,每个点显然可以继承一个儿子的信息,我们选择继承它的重儿子的信息,只 dfs 轻儿子。这样对于每个点,会被 dfs 它到根之间轻边数量次。所以复杂度是 O ( n log n ) O(n\log n)

如果需要维护的是关于深度的信息呢?我们引入长链剖分。长链剖分,类似于重链剖分,我们定义每个点的 len 为从这个点出发向下的最长链的长度,把每个点的“长儿子”定义为所有儿子里 len 最长的点。

考虑维护深度信息,发现这时候我们继承重儿子显得很浪费,我们选择继承长儿子,然后合并短儿子的深度。发现每条长链只会在链顶被遍历一遍,而长链互不相交,因此复杂度是优秀的 O ( n ) O(n)

长链剖分还有一个应用是 O ( 1 ) O(1) k k 级祖先,在这里就不啰嗦了。

例题:[WC2010]重建计划
二分答案以后,就是找边数在 [ L , U ] [L,U] 的最长链。考虑暴力的 dp,设 f [ i ] [ j ] f[i][j] 表示 i i 的子树中深度为 j j 的点与 i i 的最长距离。这是以深度为下标的信息,我们尝试用长链剖分去优化。一条链的 dfs 序是连续的一段,用 f [ d f n i + j ] f[dfn_i+j] 表示 f [ i ] [ j ] f[i][j] ,我们发现继承长儿子信息的时候深度恰好“后移”了一位。我们用线段树维护这个数组,然后合并短儿子的时候顺便统计答案。

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
struct edge{
    int to,next,w;
}ed[2000010];
int sz,head[1000010],pos[1000010],len[1000010],son[1000010],tim,w[1000010],n,L,U;
double Max[4000010],x,ad[4000010];
double z,f[1000010],g[1000010];
void add_edge(int from,int to,int w)
{
    ed[++sz].to=to;
    ed[sz].next=head[from];
    ed[sz].w=w;
    head[from]=sz;
}
void push_down(int root,int nl,int nr)
{
    if(ad[root])
    {
        Max[root<<1]+=ad[root];
        Max[root<<1|1]+=ad[root];
        ad[root<<1]+=ad[root];
        ad[root<<1|1]+=ad[root];
        ad[root]=0;
    }
}
void update(int root,int l,int r,int x,double k)
{
    if(l==r) 
    {
        Max[root]=max(Max[root],k);
        return;
    }
    int mid=l+r>>1;
    push_down(root,mid-l+1,r-mid);
    if(x<=mid) update(root<<1,l,mid,x,k);
    else update(root<<1|1,mid+1,r,x,k);
    Max[root]=max(Max[root<<1],Max[root<<1|1]);
}
double query(int root,int l,int r,int x,int y)
{
    if(x<=l&&y>=r) return Max[root];
    int mid=l+r>>1;
    push_down(root,mid-l+1,r-mid);
    if(y<=mid) return query(root<<1,l,mid,x,y);
    if(x>mid) return query(root<<1|1,mid+1,r,x,y);
    return max(query(root<<1,l,mid,x,y),query(root<<1|1,mid+1,r,x,y)); 
}
void add(int root,int l,int r,int x,int y,double k)
{
    double tmp=Max[root];
    if(x<=l&&y>=r)
    {
        Max[root]+=k;
        ad[root]+=k;
        return;
    }
    int mid=l+r>>1;
    if(x<=mid) add(root<<1,l,mid,x,y,k);
    if(y>mid) add(root<<1|1,mid+1,r,x,y,k);
    Max[root]=max(Max[root<<1],Max[root<<1|1]);
}
void dfs2(int u,int ff)
{
    if(!pos[u]) pos[u]=++tim;
    int pu=pos[u];
    if(son[u])
    {
        dfs2(son[u],u);
        add(1,1,n,pu+1,pu+len[u]-1,w[u]-x);
    }
    for(int i=head[u];i;i=ed[i].next)
    {
        int v=ed[i].to;
        if(v==son[u]||v==ff) continue;
        dfs2(v,u);
        int pv=pos[v];
        for(int j=1;j<=len[v];j++)
        {
            if(j+len[u]-1>=L&&j<=U)
            {
                double tmp=query(1,1,n,pu+max(1,L-j),pu+min(len[u]-1,U-j));
                z=max(z,tmp+ed[i].w-x+query(1,1,n,pv+j-1,pv+j-1));
            }
        }
        for(int j=1;j<=len[v];j++)
        {
            double tmp=query(1,1,n,pv+j-1,pv+j-1);
            if(tmp+ed[i].w-x>query(1,1,n,pu+j,pu+j))
            {
                update(1,1,n,pu+j,tmp+ed[i].w-x);
            }
        }
    }
    if(len[u]-1>=L) z=max(z,query(1,1,n,pu+L,pu+min(U,len[u]-1)));
}
void clear(int root,int l,int r)
{
    Max[root]=0;
    ad[root]=0;
    if(l==r) return;
    int mid=l+r>>1;
    clear(root<<1,l,mid);
    clear(root<<1|1,mid+1,r);
}
bool check(double h)
{
    clear(1,1,n);x=h;
    z=-1e18;dfs2(1,1);
    if(z>=-1e-7) return true;
    return false;
}
void dfs1(int u,int ff)
{
    len[u]=-1;
    for(int i=head[u];i;i=ed[i].next)
    {
        int v=ed[i].to;
        if(v==ff) continue;
        dfs1(v,u);
        if(len[v]>len[u]) len[u]=len[v],son[u]=v,w[u]=ed[i].w;
    }
    len[u]++;
}
int main()
{
    double l=0,r=0;
    scanf("%d%d%d",&n,&L,&U);
    for(int i=1;i<n;i++)
    {
        int u,v,w;
        scanf("%d%d%d",&u,&v,&w);
        add_edge(u,v,w);
        add_edge(v,u,w);
        r+=w;
    }
    dfs1(1,1);
    for(int i=1;i<=n;i++) len[i]++;
    double ans=0;
    r=1e6;
    while(r-l>1e-5)
    {
        double mid=(l+r)/2;
        if(check(mid)) l=mid;
        else r=mid;
    }
    printf("%.3lf\n",l);
    return 0;
}

猜你喜欢

转载自blog.csdn.net/DT_Kang/article/details/82495751