【AT3611】Tree MST

题目

这个题的输入首先就是一棵树,我们考虑一下点分

我们对于每一个分治重心考虑一下跨过这个分治重心的连边情况

就是把当前分治区域内所有的点向距离分治重心最近的点连边

考虑一下这个算法的正确性,如果我们已经对一个联通块内部形成了一个\(mst\),我们需要把这个联通块和另外一个联通块合并

如果这个新的联通块出现会使得原来联通块的\(mst\)改变,那么新出现的边也只会是原来联通块的点和新联通块到这个点距离最近的点之间的边,而这些最近的点又都是一个,所以我们就可以大大简化连边数量了

所以这个点分的过程就相当于合并\(mst\)的过程

我们点分之后发现我们连了大概\(nlogn\)条边,于是再跑一个kruskal就好了,复杂度\(O(nlog^2n)\)

代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define re register
#define LL long long
#define max(a,b) ((a)>(b)?(a):(b))
#define min(a,b) ((a)<(b)?(a):(b))
inline int read() {
    char c=getchar();int x=0;while(c<'0'||c>'9') c=getchar();
    while(c>='0'&&c<='9') x=(x<<3)+(x<<1)+c-48,c=getchar();return x;
}
const int maxn=2e5+5;
struct E{int v,nxt,w;}e[maxn<<1];
struct Edge{int a,b;LL c;}E[maxn*55];
int sum[maxn],vis[maxn],head[maxn],mx[maxn],a[maxn],fa[maxn],sz[maxn];
int n,num,m,dx,S,rt;LL dw,ans,pre[maxn];
inline void add(int x,int y,int z) {
    e[++num].v=y;e[num].nxt=head[x];head[x]=num;e[num].w=z;
}
void getroot(int x,int fa) {
    sum[x]=1,mx[x]=0;
    for(re int i=head[x];i;i=e[i].nxt) {
        if(vis[e[i].v]||e[i].v==fa) continue;
        getroot(e[i].v,x);sum[x]+=sum[e[i].v];
        mx[x]=max(mx[x],sum[e[i].v]);
    }
    mx[x]=max(mx[x],S-sum[x]);
    if(mx[x]<mx[rt]) rt=x;
}
void getdis(int x,int fa) {
    E[++m]=(Edge){dx,x,pre[x]+a[x]+dw};
    for(re int i=head[x];i;i=e[i].nxt) {
        if(vis[e[i].v]||e[i].v==fa) continue;
        getdis(e[i].v,x);
    }
}
void chk(int x,int fa) {
    if(pre[x]+a[x]<dw) dw=pre[x]+a[x],dx=x;
    for(re int i=head[x];i;i=e[i].nxt) {
        if(vis[e[i].v]||e[i].v==fa) continue;
        pre[e[i].v]=pre[x]+e[i].w;chk(e[i].v,x);
    }
}
void dfs(int x) {
    dx=x,dw=a[x];vis[x]=1;pre[x]=0,chk(x,0),getdis(x,0);
    for(re int i=head[x];i;i=e[i].nxt) {
        if(vis[e[i].v]) continue;
        S=sum[e[i].v],rt=0,getroot(e[i].v,0),dfs(rt);
    }
}
inline int cmp(Edge A,Edge B) {return A.c<B.c;}
inline int find(int x) {return x==fa[x]?x:fa[x]=find(fa[x]);}
inline int merge(int x,int y) {
    int xx=find(x),yy=find(y);
    if(xx==yy) return 0;
    if(sz[xx]<sz[yy]) fa[xx]=yy,sz[yy]+=sz[xx];
        else fa[yy]=xx,sz[xx]+=sz[yy];
    return 1;
}
int main() {
    n=read();
    for(re int i=1;i<=n;i++) a[i]=read();
    for(re int x,y,z,i=1;i<n;i++) 
        x=read(),y=read(),z=read(),add(x,y,z),add(y,x,z); 
    mx[0]=n+1,S=n,rt=0,getroot(1,0),dfs(rt);
    std::sort(E+1,E+m+1,cmp);
    for(re int i=1;i<=n;i++) sz[i]=1,fa[i]=i;
    for(re int i=1;i<=m;i++) if(merge(E[i].a,E[i].b)) ans+=E[i].c;
    std::cout<<ans;
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/asuldb/p/10977498.html
MST
今日推荐