KM算法【带权二分图完美匹配】

先orz litble——KM算法

为什么要用KM算法

因为有的题丧心病狂卡费用流
KM算法相比于费用流来说,具有更高的效率。

算法流程

我们给每一个点设一个期望值【可行顶标】
对于左边的点来说,就是期望能匹配到多大权值的右边的点
对于右边的点来说,就是期望能在左边的点的期望之上还能产生多少贡献

两个点能匹配,当且仅当它们的期望值之和为这条边的权值

一开始初始化所有左点的期望是其出边的最大值,因为最理想情况下当然是每个点都匹配自己能匹配最大的那个
右点期望为0

然后我们逐个匹配,当一个点匹配失败时,所有左点的期望就过高了
我们从右点未匹配的点中找到离被匹配相差的期望最小的点,所有已匹配左点减去这个期望值【使得能匹配的点多出了一个】,然后已匹配的右点要加上这个期望值【因为还要保证已匹配的点仍然能被匹配】
然后继续尝试匹配
直至所有点匹配完

板子:

#include<iostream>
#include<cstdio>
#include<cmath>
#include<cstring>
#include<algorithm>
#define LL long long int
#define Redge(u) for (int k = h[u],to; k; k = ed[k].nxt)
#define REP(i,n) for (int i = 1; i <= (n); i++)
#define BUG(s,n) for (int i = 1; i <= (n); i++) cout<<s[i]<<' '; puts("");
using namespace std;
const int maxn = 405,maxm = 100005,INF = 1000000000;
inline int read(){
    int out = 0,flag = 1; char c = getchar();
    while (c < 48 || c > 57){if (c == '-') flag = -1; c = getchar();}
    while (c >= 48 && c <= 57){out = (out << 3) + (out << 1) + c - 48; c = getchar();}
    return out * flag;
}
int w[maxn][maxn],expa[maxn],expb[maxn],visa[maxn],visb[maxn],cp[maxn],dl[maxn];
int n;
bool dfs(int u){
    visa[u] = true;
    REP(i,n) if (!visb[i]){
        int kl = expa[u] + expb[i] - w[u][i];
        if (kl == 0){
            visb[i] = true;
            if (!cp[i] || dfs(cp[i])){
                cp[i] = u; return true;
            }
        }
        else dl[i] = min(dl[i],kl);
        
    }
    return false;
}
int solve(){
    REP(i,n) expa[i] = expb[i] = cp[i] = 0;
    REP(i,n) REP(j,n) expa[i] = max(expa[i],w[i][j]);
    REP(i,n){
        REP(j,n) dl[j] = INF;
        while (true){
            REP(j,n) visa[j] = false,visb[j] = false;
            if (dfs(i)) break;
            int kl = INF;
            REP(j,n) if (!visb[j]) kl = min(kl,dl[j]);
            REP(j,n){
                if (visa[j]) expa[j] -= kl;
                if (visb[j]) expb[j] += kl;
                else dl[j] -= kl;
            }
        }
    }
    int re = 0;
    REP(i,n) re += w[cp[i]][i];
    return re;
}
int main(){
    n = read();
    REP(i,n) REP(j,n) w[i][j] = read();
    printf("%d\n",solve());
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/Mychael/p/8994980.html
今日推荐