某个树上数据结构 - splay - 启发式合并

题意:给你一颗有边权、点有颜色的有根树,根是1,对每个点求,仅考虑其子树,哪种颜色的点两两距离之和最大,多解输出编号最小。
题解:省选前敲一波数据结构。splay用启发式合并即可。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<utility>
#define fir first
#define sec second
#define N 3000010
#define lint long long
#define INF 1000000
#define inf -1000000
#define gc getchar()
#define mp make_pair
#define debug(x) cerr<<#x<<"="<<x
#define sp <<" "
#define ln <<endl
using namespace std;
struct edges{
    int to,pre,wgt;
}e[N];lint s[N],cnt[N],ans[N],tag[N];int h[N],etop;
int val[N],ch[N][2],fa[N],sz[N],node_cnt,a[N];
pair<lint,int> mx[N];
inline int add_edge(int u,int v,int w)
{   return e[++etop].to=v,e[etop].pre=h[u],e[etop].wgt=w,h[u]=etop; }
struct Splay{
    int rt,infp,INFp;
    inline int new_node(int v) { return val[++node_cnt]=v,node_cnt; }
    inline int init()
    {   return infp=new_node(inf),INFp=new_node(INF),rt=infp,push_up(INFp),setc(infp,INFp,1);   }
    inline int size() { return sz[rt]; }
    inline int push_up(int x)
    {   return mx[x]=max(max(mx[ch[x][0]],mx[ch[x][1]]),mp(cnt[x],val[x])),sz[x]=sz[ch[x][0]]+sz[ch[x][1]]+1;   }
    inline int setc(int x,int y,int z)
    {
        if(!x) return fa[rt=y]=0;
        ch[x][z]=y;if(y) fa[y]=x;
        return push_up(x);
    }
    inline int gw(int x) { return ch[fa[x]][1]==x; }
    inline int rotate(int x)
    {
        int y=fa[x],z=fa[y],a=gw(x),b=gw(y),c=ch[x][a^1];
        return setc(y,c,a),setc(x,y,a^1),setc(z,x,b);
    }
    inline int push_down(int x)
    {
        if(ch[x][0]) tag[ch[x][0]]+=tag[x];
        if(ch[x][1]) tag[ch[x][1]]+=tag[x];
        return s[x]+=(lint)tag[x]*cnt[x],tag[x]=0,0;
    }
    int all_push_down(int x)
    {
        if(fa[x]) all_push_down(fa[x]);
        if(tag[x]) push_down(x);return 0;
    }
    inline int splay(int x,int tar=0)
    {
        all_push_down(x);
        for(;fa[x]^tar;rotate(x))
            if(fa[fa[x]]^tar) rotate((gw(x)^gw(fa[x]))?x:fa[x]);
        return 0;
    }
    inline int insert(int v,lint c,lint ss)//update cnt[],s[]
    {
        int x=rt,las=0;
        while(x)
        {
            las=x;
            if(v<val[x]) x=ch[x][0];
            else if(v>val[x]) x=ch[x][1];
            else break;
        }
        if(!x) setc(las,x=new_node(v),v>val[las]);
        return splay(x),push_down(x),cnt[x]+=c,s[x]+=ss,push_up(x),x;
    }
    inline int merge_from(int x)//merge f
    {
        if(ch[x][0]) merge_from(ch[x][0]);
        if(val[x]>inf&&val[x]<INF) insert(val[x],cnt[x],0ll);
        if(ch[x][1]) merge_from(ch[x][1]);
        return 0;
    }
    inline int merge_from(int x,Splay &y)//merge g
    {
        if(tag[x]) push_down(x);
        if(ch[x][0]) merge_from(ch[x][0],y);
        if(val[x]>inf&&val[x]<INF)
        {
            int z=insert(val[x],cnt[x],s[x]);
            y.insert(val[x],s[x]*(cnt[z]-cnt[x])+(s[z]-s[x])*cnt[x],0ll);
        }
        if(ch[x][1]) merge_from(ch[x][1],y);
        return 0;
    }
    inline int swap(Splay &x) { return std::swap(rt,x.rt),0; }
}f[N],g[N];
inline int show(int x)
{
    debug(x)sp,debug(val[x])sp,debug(cnt[x])sp,debug(s[x])sp,debug(tag[x])sp,debug(ch[x][0])sp,debug(ch[x][1])sp,debug(mx[x].fir)sp,debug(mx[x].sec)ln;
    if(ch[x][0]) show(ch[x][0]);if(ch[x][1]) show(ch[x][1]);return 0;
}
int dfs(int x,int fa)
{
    f[x].init(),g[x].init(),g[x].insert(a[x],1ll,0ll);
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^fa)
        {
            dfs(y,x),tag[g[y].rt]+=e[i].wgt;
            if(f[x].size()<f[y].size()) f[x].swap(f[y]);
            f[x].merge_from(f[y].rt);
            if(g[x].size()<g[y].size()) g[x].swap(g[y]);
            g[x].merge_from(g[y].rt,f[x]);
        }
//  debug(x)ln,show(f[x].rt),cerr ln,show(g[x].rt),cerr ln ln;
    return ans[x]=-mx[f[x].rt].sec,0;
}
inline int sec_dfs(int x,int fa)
{
    a[x]=-a[x];
    for(int i=h[x],y;i;i=e[i].pre)
        if((y=e[i].to)^fa) a[x]=min(a[x],sec_dfs(y,x));
    return a[x];
}
inline int inn()
{
    int x,ch;while((ch=gc)<'0'||ch>'9');
    x=ch^'0';while((ch=gc)>='0'&&ch<='9')
        x=(x<<1)+(x<<3)+(ch^'0');return x;
}
char ss[3000000],tt[25];int ssl,ttl;
inline int print(lint x)
{
    if(!x) ss[++ssl]='0';
    ttl=0;while(x) tt[++ttl]=(char)(x%10)+'0',x/=10;
    while(ttl--) ss[++ssl]=tt[ttl+1];ss[++ssl]='\n';
    return 0;
}
int main()
{
//  freopen("a.in","r",stdin);
//  freopen("std.out","w",stdout);
    int n=inn();
    for(int i=1,u,v,w;i<n;i++)
        u=inn(),v=inn(),w=inn(),add_edge(u,v,w),add_edge(v,u,w);
    for(int i=1;i<=n;i++) a[i]=-inn();
    dfs(1,0),sec_dfs(1,0);
//  for(int i=1;i<=n;i++) debug(i)ln,show(f[i].rt),cerr ln ln,show(g[i].rt),cerr ln ln;
    for(int i=1;i<=n;i++) print(ans[i]==-INF?a[i]:ans[i]);
    return fwrite(ss+1,sizeof(char),ssl,stdout),0;
}

猜你喜欢

转载自blog.csdn.net/mys_c_k/article/details/80200906