[学习笔记]虚树

模板:(树剖\(LCA\)+建虚树)

#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

int main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}

具体建虚树怎么建可以看别人的博客……我讲的肯定没有它们好

1、[SDOI2011]消耗战

分析:人生第一道虚树题。

难的就是建一棵虚树,然后在虚树上树形 \(dp\)

首先,打出一个树上前缀最小值。因为无论怎样,选一条最小的边断掉一定是最优的。建一棵虚树,若遍历到选定的点 \(x\),那么 \(dp[x]=min(dis[x],\sum_{son\in x}val_{x->son})\),其中 \(val\) 为边权。

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=250000+10;
const int inf=1e18;
int n,m,dp[maxn],dis[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next,val;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y,int w){
    e[++tot].to=y;
    e[tot].val=w;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dis[y]=min(dis[x],e[i].val);
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

void dfs(int x,int flag){
    dp[x]=dis[x];
    if(flag){
        for(int i=fir[x];i;i=nxt[i])
            dfs(to[i],flag);
        fir[x]=vis[x]=0;
        return ;
    }
    int val=0;
    for(int i=fir[x],y;i;i=nxt[i]){
        y=to[i];
        dfs(y,vis[y]);
        val+=dp[y];
    }
    if(!fir[x]||vis[x]) val=inf;
    dp[x]=min(dp[x],val);
    fir[x]=vis[x]=0;
}

signed main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read(),w=read();
        add(x,y,w);add(y,x,w);
    }
    dis[1]=inf;
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;
        sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top]);
                    if(lca!=sta[--Top]) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        dfs(1,0);
        printf("%lld\n",dp[1]);
    }
    return 0;
}

2、[HEOI2014]大工程

分析:这道题自己推的,很有成就感哈哈哈

方法与上题一样,不过多一点细节

\(sub[x]\) 表示在虚树上 \(x\) 的子树内有多少个选定点

这些点对 \((x,y)\) 对答案的贡献要分两类讨论:

1、\(x=lca(x,y)\) ,那么直接在 \(vis[x]=1\) 的时候算掉

2、\(x,y\) 在两棵不同的子树内,那就一边更新 \(sub[x]\) 一边算

void dfs(int x){
    int now=0;
    for(int i=fir[x];i;i=nxt[i]){
        dfs(to[i]);
        now+=sub[x]*sub[to[i]];
        sub[x]+=sub[to[i]];
        sub[to[i]]=0;
    }
    ans-=2*now*dep[x];
    if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
}

找最小边权就是记录一下最小值和次小值,然后更新 \(ans\)

找最大边权同个道理

int dfs_min(int x){
    int Min=inf,sec=inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=min(sec,dfs_min(to[i]));
        if(Min>sec) swap(Min,sec);
    }
    if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
    if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
    if(vis[x]) Min=dep[x];
    return Min;
}

int dfs_max(int x){
    int Max=-inf,sec=-inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=max(sec,dfs_max(to[i]));
        if(Max<sec) swap(Max,sec);
    }
    if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
    if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
    if(vis[x]&&Max==-inf) Max=dep[x];
    fir[x]=0;
    return Max;
}

那个前式链向星数组 \(fir[x]\) 一定要在 \(dfsmax()\) 的时候清空!!!

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1000000+10;
const int inf=1e18;
int n,m,dp[maxn],vis[maxn],h[maxn],sub[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim,ans;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

void dfs(int x){
    int now=0;
    for(int i=fir[x];i;i=nxt[i]){
        dfs(to[i]);
        now+=sub[x]*sub[to[i]];
        sub[x]+=sub[to[i]];
        sub[to[i]]=0;
    }
    ans-=2*now*dep[x];
    if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
    //printf("x=%lld,now=%lld,ans=%lld,sub[x]=%lld,dep[x]=%lld\n",x,now,ans,sub[x],dep[x]);
}

int dfs_min(int x){
    int Min=inf,sec=inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=min(sec,dfs_min(to[i]));
        if(Min>sec) swap(Min,sec);
    }
    if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
    if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
    if(vis[x]) Min=dep[x];
    return Min;
}

int dfs_max(int x){
    int Max=-inf,sec=-inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=max(sec,dfs_max(to[i]));
        if(Max<sec) swap(Max,sec);
    }
    if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
    if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
    if(vis[x]&&Max==-inf) Max=dep[x];
    fir[x]=0;
    return Max;
}

signed main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        if(k==1){
            printf("0 0 0\n");
            continue;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        ans=0;
        for(int i=1;i<=k;i++) ans+=(k-1)*dep[h[i]];
        dfs(1);sub[1]=0;
        printf("%lld ",ans);
        ans=inf;dfs_min(1);
        printf("%lld ",ans);
        ans=-inf;dfs_max(1);
        printf("%lld\n",ans);
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}

3、CF613D Kingdom and its Cities

分析:在虚树上树形 \(dp\) 的时候分三种情况:

\(P.S:sum\) 表示有多少个儿子已经选了

1、\(vis[x]=1\),那么不能选 \(x\) 来断掉儿子的退路,那么 \(dp[x]=\sum_{son\in x} dp[son]\)

2、\(vis[x]=0,sum>1\),那就直接选 \(x\)\(x\) 的子树已经被 \(x\) 封死了

3、\(vis[x]=0,sum\leq 1\),那就传到 \(x\) 的父亲上,让 \(x\) 的父亲解决好了

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

int dfs(int x){
    int ans=0,sum=0;
    for(int i=fir[x];i;i=nxt[i])
        ans+=dfs(to[i]),sum+=dp[to[i]];
    if(vis[x]) dp[x]=1,ans+=sum;
    else if(sum>1) dp[x]=0,ans++;
    else dp[x]=sum;
    fir[x]=0;
    return ans;
}

int main()
{
    n=read();
    int x,y,w,k,lca,flag;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        flag=0;
        for(int i=1;i<=k;i++)
            flag|=vis[fa[h[i]]];
        if(flag){
            printf("-1\n");
            for(int i=1;i<=k;i++) vis[h[i]]=0;
            continue;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        printf("%d\n",dfs(1));
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/owencodeisking/p/9965535.html