CF786E ALT 最小割+倍增lca

版权声明:2333 https://blog.csdn.net/liangzihao1/article/details/81912729

题目大意:
给你一棵 n ( n <= 2 10 4 ) 个点的树和 m ( m <= 10 4 ) 个人,第 i 个人要从 a i b i
你可以选择给一个人一只小狗或者在树上一条边上放一只小狗。
如果一个人路径上所有边都有小狗,或者给他一只小狗,那么这个人就是高兴的。
求最少需要多少条狗才能人所有人都是高兴的,并输出一组解。

分析:
我们要从两个方案中选择一个,可以考虑最小割。
一条从 x y 流量为 w 的边表示为 ( x , y , w )
对于每个人建一个点 i ,连一条 ( S , i , 1 ) 的边。
每条边也建一个点 j ,连一条 ( j , T , 1 ) 的边。
每个人 i 向他路径上的边 j 连一条 ( i , j , i n f ) 的边。
这样可能边数会很多。比如说树是一条链,每条路径都是从链的一段到另一段,这样每个 i 都要向所有点连边,空间上无法承受。所以我们考虑倍增连边,就是大区间向两个小区间连一条 i n f 的边。这样一个点最多只会连 O ( l o g n ) 条边,然后跑网络流。
至于输出方案,我们从S对残余网络 b f s 。如果一条 ( S , i , 1 ) 的边满流,而且 S 无法到达 i ,那么这条边被割,也就是给 i 一只小狗;如果我们能到达 j ,根据最小割的原理,那么 ( j , T , 1 ) 一定被割,表示这条边上放一只小狗。

代码:

#include <iostream>
#include <cstdio>
#include <cmath>
#include <vector>
#include <queue>

const int maxn=4e5+7;
const int inf=0x3f3f3f3f;

using namespace std;

int n,m,x,y,num,cnt,s,t,ans;
int ls[maxn],f[maxn][15],h[maxn][15],dep[maxn],dis[maxn],po[maxn],vis[maxn];

struct edge{
    int y,w,op,next;
}g[maxn*20];

struct rec{
    int y,num;
};

vector <rec> e[maxn];
vector <int> a,b;
queue <int> q;

void dfs(int x,int fa)
{
    f[x][0]=fa;
    dep[x]=dep[fa]+1;
    for (int i=0;i<e[x].size();i++)
    {
        int y=e[x][i].y;
        if (y==fa) continue;
        h[y][0]=++num;
        po[num]=e[x][i].num;
        dfs(y,x);
    }
}

void add(int x,int y,int w)
{
    g[++cnt]=(edge){y,w,cnt+1,ls[x]};
    ls[x]=cnt;
    g[++cnt]=(edge){x,0,cnt-1,ls[y]};
    ls[y]=cnt;
}

void add_link(int x,int y,int p)
{
    if (dep[x]>dep[y]) swap(x,y);
    int d=dep[y]-dep[x],k=14,t=1<<14;
    while (d)
    {
        if (d>=t)
        {
            d-=t;
            add(p,h[y][k],inf);
            y=f[y][k];
        }
        k--;
        t/=2;
    }
    if (x==y) return;
    k=14;
    while (k>=0)
    {
        if (f[x][k]!=f[y][k])
        {
            add(p,h[x][k],inf);
            add(p,h[y][k],inf);
            x=f[x][k];
            y=f[y][k];
        }
        k--;
    }
    add(p,h[x][0],inf);
    add(p,h[y][0],inf);
}

bool bfs()
{
    for (int i=s;i<=t;i++) dis[i]=inf;
    dis[s]=0;
    q.push(s);
    while (!q.empty())
    {
        int x=q.front();
        q.pop();
        for (int i=ls[x];i>0;i=g[i].next)
        {
            int y=g[i].y;
            if ((g[i].w) && (dis[y]>dis[x]+1))
            {
                dis[y]=dis[x]+1;
                q.push(y);
            }
        }
    }
    return (dis[t]!=inf);
}

int dinic(int x,int maxf)
{
    if ((x==t) || (!maxf)) return maxf;
    int ret=0;
    for (int i=ls[x];i>0;i=g[i].next)
    {
        int y=g[i].y;
        if ((g[i].w) && (dis[y]==dis[x]+1))
        {
            int f=dinic(y,min(maxf-ret,g[i].w));
            if (!f) dis[y]=0;
            ret+=f;
            g[i].w-=f;
            g[g[i].op].w+=f;
        }
    }
    return ret;
}

void calc()
{   
    q.push(s);
    vis[s]=1;
    while (!q.empty())
    {
        int x=q.front();
        q.pop();
        for (int i=ls[x];i>0;i=g[i].next)
        {
            int y=g[i].y;
            if ((g[i].w) && (!vis[y]))
            {
                vis[y]=1;
                q.push(y);
            }
        }
    }
    for (int i=1;i<=m;i++)
    {
        if (!vis[i]) a.push_back(i);
    }
    for (int i=1;i<n;i++)
    {
        if (vis[i+m]) b.push_back(po[i+m]);
    }
}

int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&x,&y);
        e[x].push_back((rec){y,i});
        e[y].push_back((rec){x,i});
    }
    s=0;
    num=m;  
    dfs(1,0);           
    for (int j=1;j<15;j++)
    {
        for (int i=1;i<=n;i++)
        {
            f[i][j]=f[f[i][j-1]][j-1];
            if (i!=1) h[i][j]=++num;
            if (h[i][j-1]) add(h[i][j],h[i][j-1],inf);
            if (h[f[i][j-1]][j-1]) add(h[i][j],h[f[i][j-1]][j-1],inf);
        }
    }   
    t=num+1;
    for (int i=1;i<=n;i++) if (i!=1) add(h[i][0],t,1);
    for (int i=1;i<=m;i++)
    { 
        scanf("%d%d",&x,&y);
        add(s,i,1);
        add_link(x,y,i);
    }
    while (bfs()) ans+=dinic(s,inf);
    printf("%d\n",ans);     
    calc();
    printf("%d ",a.size());
    for (int i=0;i<a.size();i++) printf("%d ",a[i]);
    printf("\n");
    printf("%d ",b.size());
    for (int i=0;i<b.size();i++) printf("%d ",b[i]);
}

猜你喜欢

转载自blog.csdn.net/liangzihao1/article/details/81912729