Luogu P4180 [Template] Strictly Second-Minimum Spanning Tree [BJWC2010]

Strictly second-minimum spanning tree: template solution
Algorithm:
First build a minimum spanning tree (MST) with Kruskal's algorithm, then apply heavy-light decomposition to the MST. Each edge's weight is moved onto its deeper endpoint (turning edge weights into vertex weights). A segment tree then maintains the maximum and the strict second maximum of every range.
Then enumerate every edge that was not selected for the MST, and query the longest edge on the MST path between its two endpoints. If that longest edge has the same weight as the enumerated edge, use the strictly second-longest edge instead; skip the enumerated edge if no such edge exists. The candidate answer is the MST weight minus the chosen longest/second-longest path edge, plus the weight of the enumerated edge. The answer is the smallest candidate.
This works because adding the enumerated edge to the MST forms a cycle, and one edge of that cycle must be removed to obtain a spanning tree again.

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstdlib>
using namespace std;
// Capacity bound shared by every per-vertex and per-edge array below.
const int N=300005;
int n;        // number of vertices
int m;        // number of input edges
int h[N];     // h[u]: first adjacency record of vertex u (0 = none)
int cnt;      // adjacency records used so far
int f[N];     // union-find parent used by Kruskal
int con;      // MST edges accepted so far
int fa[N];    // parent of each vertex in the rooted MST
int si[N];    // subtree size
int hs[N];    // heavy child (0 = leaf)
int de[N];    // depth (root has depth 1)
int fr[N];    // top vertex of the heavy chain containing the vertex
int id[N];    // DFS position of each vertex
int rl[N];    // rl[pos]: vertex at DFS position pos (inverse of id)
int va[N];    // weight of the MST edge from the vertex to its parent
int tmp;      // DFS position counter
long long ans=1e18;  // best candidate found so far
long long sum;       // total MST weight
bool mk[N];   // mk[i]: sorted edge i was taken into the MST
// Adjacency-list record; each MST edge is stored twice (once per direction).
struct qwe
{
    int ne;  // next record leaving the same vertex (0 terminates the list)
    int to;  // endpoint this record points to
    int va;  // edge weight
}e[N<<1];
// Segment-tree node covering DFS positions [l, r].
// mx  = maximum value in the range;
// cmx = largest value strictly below mx, or cmx == mx when no such value exists.
// Nodes are heap-indexed (root 1, children ro<<1 and ro<<1|1), which can
// address close to 4*n slots, so the array is sized N<<2 rather than N<<1.
struct xds
{
    int l,r,mx,cmx;
}t[N<<2];
// One input edge (u, v) of weight w; a[1..m] holds the whole edge list.
struct bian
{
    int u;
    int v;
    int w;
}a[N*3];
// Orders edges by ascending weight, the order Kruskal's algorithm consumes them in.
bool cmp(const bian &a,const bian &b)
{
    return b.w>a.w;
}
// Reads the next (optionally '-'-signed) decimal integer from stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. Returns 0 if input ends before any digit is found.
int read()
{
    int value=0,sign=1;
    // Keep getchar()'s result as int: storing it in char would make EOF
    // indistinguishable from a byte (and never equal to EOF if char is unsigned).
    int p=getchar();
    // Stop at EOF; otherwise this loop spins forever on truncated input.
    while(p!=EOF&&(p<'0'||p>'9'))
    {
        if(p=='-')
            sign=-1;
        p=getchar();
    }
    while(p>='0'&&p<='9')
    {
        value=value*10+(p-'0');
        p=getchar();
    }
    return value*sign;
}
// Union-find lookup: returns the representative of x and points every vertex
// on the walked path directly at it (path compression), without recursion.
inline int zhao(int x)
{
    int root=x;
    while(f[root]!=root)
        root=f[root];
    while(f[x]!=root)
    {
        int next=f[x];
        f[x]=root;
        x=next;
    }
    return root;
}
void add(int u,int v,int w)
{
    cnt++;
    e[cnt].ne=h[u];
    e[cnt].to=v;
    e[cnt].va=w;
    h[u]=cnt;
}
// First HLD pass: records parent, depth and subtree size, picks the heavy
// child, and copies each tree edge's weight onto its child endpoint (va).
void dfs1(int u,int fat)
{
    fa[u]=fat;
    de[u]=de[fat]+1;
    si[u]=1;
    for(int k=h[u];k!=0;k=e[k].ne)
    {
        const int son=e[k].to;
        if(son==fat)
            continue;
        va[son]=e[k].va;
        dfs1(son,u);
        si[u]+=si[son];
        if(si[hs[u]]<si[son])
            hs[u]=son;
    }
}
// Second HLD pass: assigns DFS positions (heavy child first so every heavy
// chain is contiguous) and the chain top of each vertex.
void dfs2(int u,int top)
{
    fr[u]=top;
    const int pos=++tmp;
    id[u]=pos;
    rl[pos]=u;
    if(hs[u]==0)
        return;
    dfs2(hs[u],top);
    for(int k=h[u];k;k=e[k].ne)
    {
        const int son=e[k].to;
        if(son==fa[u]||son==hs[u])
            continue;
        // A light child starts its own chain.
        dfs2(son,son);
    }
}
void build(int ro,int l,int r)
{
    t[ro].l=l,t[ro].r=r;
    if(l==r)
    {
        t[ro].mx=t[ro].cmx=va[rl[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(ro<<1,l,mid);
    build(ro<<1|1,mid+1,r);
    t[ro].mx=max(t[ro<<1].mx,t[ro<<1|1].mx);
    if(t[ro<<1].mx==t[ro<<1|1].mx)
    {
        if(max(t[ro<<1].cmx,t[ro<<1|1].cmx)==t[ro<<1].mx)
            t[ro].cmx=min(t[ro<<1].cmx,t[ro<<1|1].cmx);
        else
            t[ro].cmx=max(t[ro<<1].cmx,t[ro<<1|1].cmx);
    }
    else
        t[ro].cmx=min(t[ro<<1].mx,t[ro<<1|1].mx);
}
int ques(int ro,int l,int r,int w)
{
    if(t[ro].l==l&&t[ro].r==r)
        return t[ro].mx==w?t[ro].cmx:t[ro].mx;
    int mid=(t[ro].l+t[ro].r)>>1;
    if(r<=mid)
        return ques(ro<<1,l,r,w);
    else if(l>mid)
        return ques(ro<<1|1,l,r,w);
    else
    {
        int x=ques(ro<<1,l,mid,w),y=ques(ro<<1|1,mid+1,r,w);
        return (max(x,y)==w)?min(x,y):max(x,y);
    }
}
int wen(int u,int v,int w)
{
    int re=0;
    while(fr[u]!=fr[v])
    {
        if(de[fr[u]]<de[fr[v]])
            swap(u,v);
        re=max(re,ques(1,id[fr[u]],id[u],w));
        u=fa[fr[u]];
    }
    if(u!=v)
    {
        if(de[u]>de[v])
            swap(u,v);
        re=max(re,ques(1,id[u]+1,id[v],w));
    }
    if(re==w)
        return 0;
    return re;
}
int main()
{
    n=read(),m=read();
    for(int i=1;i<=m;i++)
        a[i].u=read(),a[i].v=read(),a[i].w=read();
    sort(a+1,a+1+m,cmp);
    for(int i=1;i<=n;i++)
        f[i]=i;
    for(int i=1;i<=m&&con<n-1;i++)
    {
        int fu=zhao(a[i].u),fv=zhao(a[i].v);
        if(fu!=fv)
        {
            f[fu]=fv,con++,sum+=a[i].w;
            add(a[i].u,a[i].v,a[i].w),add(a[i].v,a[i].u,a[i].w);
            mk[i]=1;
        }
    }
    dfs1(1,0);
    dfs2(1,1);
    build(1,1,n);
    for(int i=1;i<=m;i++)
        if(!mk[i])
            ans=min(ans,sum-wen(a[i].u,a[i].v,a[i].w)+a[i].w);
    printf("%lld\n",ans);
    return 0;
}

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=325124054&siteId=291194637