学习笔记--最小费用流之原始对偶

兼具zkw和spfa的优点,折中的一种算法,通过spfa跑出最短路,然后更改边的权值(加上dis【from】-dis【to】),那么如果为0就是在from到to最短路上的点,相当于一种分层(个人理解),就可以用多路增广来搞了。而这里的spfa除了第一次外甚至可以拿dij来替换,不过因为加了slf优化的spfa已经还不错了就没去写。(ps:网上大多spfa都要从汇点跑到源点,个人觉得从源到汇也没什么区别a,有知道的大佬能来解释下吗。。)

代码:(分配工作)

#pragma GCC optimize(3,"inline","Ofast")
#include<bits/stdc++.h>
using namespace std;
const int N=1010,M=250010;
const int s=0,t=1005;
void read(int &x)
{
    char c=getchar();x=0;
    while(!isdigit(c))c=getchar();
    while(isdigit(c))x=(x<<3)+(x<<1)+c-48,c=getchar();
}
int n,m,k,hd[N],nxt[M*3],to[M*3],cost[M*3],las[M*3],tot=-1,mp[510][510];
int cur[N],dis[N],q[M],D=0,ans=0;
bool vis[N],inq[N];
void add(int u,int v,int w,int c)
{
    nxt[++tot]=hd[u],to[tot]=v,las[tot]=w,cost[tot]=c,hd[u]=tot;
    nxt[++tot]=hd[v],to[tot]=u,las[tot]=0,cost[tot]=-c,hd[v]=tot;
}
bool spfa()
{
    int nw,hed=N,tail=N+1;
    memset(dis,127,sizeof dis);
    memset(inq,0,sizeof inq);
    dis[s]=0,q[++hed]=s,inq[s]=1;
    while(hed>=tail)
    {
        nw=q[tail++];
        for(int i=hd[nw];i!=-1;i=nxt[i])
        {
            if(las[i]&&dis[to[i]]>dis[nw]+cost[i])
            {
                dis[to[i]]=dis[nw]+cost[i];
                if(!inq[to[i]])
                {
                    if(dis[to[i]]<dis[q[tail]])q[--tail]=to[i];
                    else q[++hed]=to[i];
                }
                inq[to[i]]=1;
            }
        }
        inq[nw]=0;
    }
    for(int i=0;i<=t;i++)
        for(int j=hd[i];j!=-1;j=nxt[j])
            cost[j]-=dis[to[j]]-dis[i];
    D+=dis[t];
    return dis[t]<=2e9;
}
int dfs(int pos,int flow)
{
    if(pos==t)return ans+=flow*D,flow;
    vis[pos]=1;
    int l=flow,tp;
    for(int &i=cur[pos];i!=-1;i=nxt[i])
    {
        if(las[i]&&!cost[i]&&!vis[to[i]])
        {
            tp=dfs(to[i],min(l,las[i]));
            las[i]-=tp,las[i^1]+=tp;
            l-=tp;
            if(!l)return flow;
        }
    }
    return flow-l;
}
void mcmf()
{
    while(spfa())
    {
        do{
            memset(vis,0,sizeof vis);
            for(int i=s;i<=t;i++)cur[i]=hd[i];
        }while(dfs(s,2e9));
    }
    printf("%d",-ans);
}
int main()
{
    int u,v,w;
    memset(hd,-1,sizeof hd);
    memset(nxt,-1,sizeof nxt);
    memset(mp,127,sizeof 127);
    read(n),read(m),read(k);
    for(int i=1;i<=k;i++)
    {
        read(u),read(v),read(w);
        mp[v][u]=min(mp[v][u],-w);
    }
    for(int i=1;i<=m;i++)
        add(s,i,1,0);
    for(int i=1;i<=n;i++)
        add(i+m,t,1,0);
    for(int i=1;i<=m;i++)
        for(int j=1;j<=n;j++)
            if(mp[i][j]<=2e9)add(i,j+m,1,mp[i][j]);
    mcmf();
}

猜你喜欢

转载自blog.csdn.net/caoyang1123/article/details/82619833