【运筹优化】最大二分匹配问题及两种算法详解 + Java代码实现


一、最大二分匹配问题

1.1 二分图

二分图,又称二部图,英文名叫 Bipartite graph。

二分图是什么?节点由两个集合组成,且两个集合内部没有边的图。

换言之,存在一种方案,将节点划分成满足以上性质的两个集合。

下图展示了一个二分图的示例:

在这里插入图片描述

1.2 最大二分匹配问题介绍

在最大二分匹配(MBM)问题中,给定一个二分图 G,即分左右两部分,两个部分内部的点没有边连接,要求选出一些边,使得这些边没有公共顶点,且边的数量最大。

上面的描述可能比较抽象,我们可以假想成男女配对。如下图所示,男女之间的连线代表他们互相之间有好感,假设你现在是一个媒婆,你的目标是在不违反男女双方自由意志的前提下,尽可能撮合出最多对的情侣。这就是最大二分匹配问题!

在这里插入图片描述


二、匈牙利算法

参考链接:优化 | 二部图最大匹配问题的精确算法详解(HK算法和匈牙利算法):一份让您满意的【理论介绍+代码实现】学习笔记


三、HK算法

参考链接:优化 | 二部图最大匹配问题的精确算法详解(HK算法和匈牙利算法):一份让您满意的【理论介绍+代码实现】学习笔记


四、Java代码实现

4.1 匈牙利算法

import java.util.*;

public class Hungarian {
    
    
    public void solve(HashMap<Object, List<Object>> graph) {
    
    
        // 获取左端点集合
        HashSet<Object> leftSet = new HashSet<>(graph.keySet());
        // 获取右端点集合
        HashSet<Object> rightSet = new HashSet<>();
        // 将graph补充完整
        for (Object left : leftSet) {
    
    
            for (Object right : graph.get(left)) {
    
    
                if (!graph.containsKey(right)) {
    
    
                    graph.put(right, new ArrayList<>());
                    rightSet.add(right);
                }
            }
        }
        System.out.println("leftSet: " + leftSet);
        System.out.println("rightSet: " + rightSet);
        System.out.println("Graph: " + graph);
        // 开始遍历左端点集合,寻找增广路
        while (true) {
    
    
            int expandRouteNum = 0;
            // 记录是否被DFS改变了graph结构
            boolean isDfsChange = false;
            // 开始本次循环,遍历所有free的左节点
            for (Object left : leftSet) {
    
    
                if (!isConnected(graph, left, rightSet)) {
    
    
                    boolean b = false;
                    if (graph.get(left).size() > 0) {
    
    
                        // 直接就能找到增广路
                        for (int i = 0; i < graph.get(left).size(); i++) {
    
    
                            if (graph.get(graph.get(left).get(i)).isEmpty()) {
    
    
                                graph.get(graph.get(left).remove(i)).add(left);
                                b = true;
                                expandRouteNum++;
                                break;
                            }
                        }
                    }
                    // 没找到直接连接的增广路,那就DFS找增广路
                    if (!b && expandRouteNum < rightSet.size()) {
    
    
                        List<Object> path = new ArrayList<>();
                        path.add(left);
                        if (dfs(path, graph, leftSet, rightSet)) {
    
    
                            // 找到了增广路,根据增广路更新graph
                            for (int i = 0; i < path.size() - 1; i++) {
    
    
                                graph.get(path.get(i)).remove(path.get(i + 1));
                                graph.get(path.get(i + 1)).add(path.get(i));
                                expandRouteNum++;
                            }
                            isDfsChange = true;
                        } else {
    
    
                            break;
                        }
                    }
                }
            }
            if (!isDfsChange) {
    
    
                break;
            }
        }
        // 输出结果
        int cnt = 0;
        for (Object key : graph.keySet()) {
    
    
            if (rightSet.contains(key) && !graph.get(key).isEmpty()) {
    
    
                if (graph.get(key).size() > 1) {
    
    
                    throw new RuntimeException();
                }
                System.out.println("Match-" + (++cnt) + ": " + graph.get(key).get(0) + "->" + key);
            }
        }
        System.out.println("最大匹配数为: " + cnt);
    }

    // 判断当前graph有没有right点与该left点相连
    private boolean isConnected(HashMap<Object, List<Object>> graph, Object left, HashSet<Object> rightSet) {
    
    
        for (Object right : rightSet) {
    
    
            if (graph.get(right).contains(left)) {
    
    
                return true;
            }
        }
        return false;
    }

    // DFS 寻找增广路
    private boolean dfs(List<Object> path, HashMap<Object, List<Object>> graph, HashSet<Object> leftSet, HashSet<Object> rightSet) {
    
    
        Object curNode = path.get(path.size() - 1);
        // 当前路径是否为增广路的判断
        if (rightSet.contains(curNode) && graph.get(curNode).isEmpty()) {
    
    
            return true;
        }
        for (Object nextNode : graph.keySet()) {
    
    
            if (graph.get(curNode).contains(nextNode) && !isExitInPath(path, nextNode)) {
    
    
                if (leftSet.contains(curNode) && rightSet.contains(nextNode)) {
    
    
                    // 当前点是左边,nextNode是右边
                    path.add(nextNode);
                    if (dfs(path, graph, leftSet, rightSet)) {
    
    
                        return true;
                    }
                    path.remove(path.size() - 1);
                } else if (rightSet.contains(curNode) && leftSet.contains(nextNode)) {
    
    
                    // 当前点是右边,nextNode是左边
                    path.add(nextNode);
                    if (dfs(path, graph, leftSet, rightSet)) {
    
    
                        return true;
                    }
                    path.remove(path.size() - 1);
                }
            }
        }
        return false;
    }

    // 判断路径中是否出现过该点
    private boolean isExitInPath(List<Object> path, Object node) {
    
    
        for (Object o : path) {
    
    
            if (o.equals(node)) {
    
    
                return true;
            }
        }
        return false;
    }

}

4.2 HK 算法

import java.util.*;

public class HopcroftKarp {
    
    
    public void solve(HashMap<Object, List<Object>> graph) {
    
    
        // 获取左端点集合
        HashSet<Object> leftSet = new HashSet<>(graph.keySet());
        // 获取右端点集合
        HashSet<Object> rightSet = new HashSet<>();
        // 将graph补充完整
        for (Object left : leftSet) {
    
    
            for (Object right : graph.get(left)) {
    
    
                if (!graph.containsKey(right)) {
    
    
                    graph.put(right, new ArrayList<>());
                    rightSet.add(right);
                }
            }
        }
        System.out.println("leftSet: " + leftSet);
        System.out.println("rightSet: " + rightSet);
        System.out.println("Graph: " + graph);
        // 开始循环
        int expandRouteNum = 0;
        while (true) {
    
    
            // 先把能直接连接的点连接了
            boolean b = false;
            for (Object left : leftSet) {
    
    
                if (!isConnected(graph, left, rightSet)) {
    
    
                    if (graph.get(left).size() > 0) {
    
    
                        // 直接就能找到增广路
                        for (int i = 0; i < graph.get(left).size(); i++) {
    
    
                            if (graph.get(graph.get(left).get(i)).isEmpty()) {
    
    
                                graph.get(graph.get(left).remove(i)).add(left);
                                b = true;
                                expandRouteNum++;
                                break;
                            }
                        }
                    }
                }
            }
            if (!b) {
    
    
                // BFS 获取 level graph
                List<Set<Object>> levelGraph = buildLevelGraph(graph, leftSet, rightSet);
                if (levelGraph.size() > 2) {
    
    
                    // DFS 在 level graph 中搜索多条可能存在的增广路径
                    List<List<Object>> pathList = new ArrayList<>();
                    Set<Object> visited = new HashSet<>();
                    for (Object firstLevelNode : levelGraph.get(0)) {
    
    
                        if (!visited.contains(firstLevelNode)) {
    
    
                            List<Object> curPath = new ArrayList<>();
                            curPath.add(firstLevelNode);
                            List<Object> findPath = dfs(1, curPath, levelGraph, graph, visited, rightSet);
                            if (findPath != null) {
    
    
                                pathList.add(findPath);
                                visited.addAll(findPath);
                            }
                        }
                    }
                    if (pathList.isEmpty()) {
    
    
                        break;
                    } else {
    
    
                        // 根据找到的增广路更新 graph
                        for (List<Object> path : pathList) {
    
    
                            for (int i = 0; i < path.size() - 1; i++) {
    
    
                                graph.get(path.get(i)).remove(path.get(i + 1));
                                graph.get(path.get(i + 1)).add(path.get(i));
                                expandRouteNum++;
                            }
                        }
                    }
                } else {
    
    
                    break;
                }
            }
        }
        // 输出结果
        int cnt = 0;
        for (Object key : graph.keySet()) {
    
    
            if (rightSet.contains(key) && !graph.get(key).isEmpty()) {
    
    
                if (graph.get(key).size() > 1) {
    
    
                    throw new RuntimeException();
                }
                System.out.println("Match-" + (++cnt) + ": " + graph.get(key).get(0) + "->" + key);
            }
        }
        System.out.println("最大匹配数为: " + cnt);
    }

    // 在 Level Graph 上进行 DFS ,找出多条增广路
    private List<Object> dfs(int curLevel, List<Object> curPath, List<Set<Object>> levelGraph, HashMap<Object, List<Object>> graph, Set<Object> visited, HashSet<Object> rightSet) {
    
    
        if (curLevel < levelGraph.size()) {
    
    
            for (Object curLevelNode : levelGraph.get(curLevel)) {
    
    
                if (!visited.contains(curLevelNode) && graph.get(curPath.get(curPath.size() - 1)).contains(curLevelNode)) {
    
    
                    curPath.add(curLevelNode);
                    List<Object> findPath = dfs(curLevel + 1, curPath, levelGraph, graph, visited, rightSet);
                    if (findPath != null) {
    
    
                        return findPath;
                    }
                    curPath.remove(curPath.size() - 1);
                }
            }
        } else {
    
    
            if (rightSet.contains(curPath.get(curPath.size() - 1))) {
    
    
                return curPath;
            }
        }
        return null;
    }

    // 构建 Level Graph
    private List<Set<Object>> buildLevelGraph(HashMap<Object, List<Object>> graph, HashSet<Object> leftSet, HashSet<Object> rightSet) {
    
    
        Set<Object> visited = new HashSet<>();
        List<Set<Object>> levelGraph = new ArrayList<>();
        Set<Object> level1 = new HashSet<>();
        for (Object left : leftSet) {
    
    
            if (!isConnected(graph, left, rightSet)) {
    
    
                level1.add(left);
                visited.add(left);
            }
        }
        levelGraph.add(level1);
        while (true) {
    
    
            Set<Object> newLevel = new HashSet<>();
            for (Object lastLevelNode : levelGraph.get(levelGraph.size() - 1)) {
    
    
                for (Object o : graph.get(lastLevelNode)) {
    
    
                    if (visited.add(o)) {
    
    
                        newLevel.add(o);
                    }
                }
            }
            if (!newLevel.isEmpty()) {
    
    
                levelGraph.add(newLevel);
            } else {
    
    
                break;
            }
        }
        boolean b = false;
        for (Object o : levelGraph.get(levelGraph.size() - 1)) {
    
    
            b = leftSet.contains(o);
            break;
        }
        if (b) {
    
    
            levelGraph.remove(levelGraph.size() - 1);
        }
        return levelGraph;
    }

    // 判断当前graph有没有right点与该left点相连
    private boolean isConnected(HashMap<Object, List<Object>> graph, Object left, HashSet<Object> rightSet) {
    
    
        for (Object right : rightSet) {
    
    
            if (graph.get(right).contains(left)) {
    
    
                return true;
            }
        }
        return false;
    }

}

4.3 算法测试

测试案例

在这里插入图片描述

测试代码

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;

public class Test {
    
    

    public static void main(String[] args) {
    
    
        HashMap<Object, List<Object>> graph = new HashMap<>();
        graph.put("A", new ArrayList<>());
        graph.get("A").add(1);
        graph.get("A").add(4);

        graph.put("B", new ArrayList<>());
        graph.get("B").add(3);
        graph.get("B").add(6);
        graph.get("B").add(7);

        graph.put("C", new ArrayList<>());
        graph.get("C").add(2);
        graph.get("C").add(4);
        graph.get("C").add(5);

        graph.put("D", new ArrayList<>());
        graph.get("D").add(2);
        graph.get("D").add(7);

        graph.put("E", new ArrayList<>());
        graph.get("E").add(5);
        graph.get("E").add(6);
        graph.get("E").add(7);

        graph.put("F", new ArrayList<>());
        graph.get("F").add(3);
        graph.get("F").add(6);

        graph.put("G", new ArrayList<>());
        graph.get("G").add(6);
        graph.get("G").add(7);

        System.out.println("------------------ 匈牙利算法 ------------------");
        long s = System.currentTimeMillis();
        new Hungarian().solve(copyGraph(graph));
        System.out.println("用时: " + (System.currentTimeMillis() - s) + " ms");

        System.out.println("------------------ HK 算法 ------------------");
        s = System.currentTimeMillis();
        new HopcroftKarp().solve(copyGraph(graph));
        System.out.println("用时: " + (System.currentTimeMillis() - s) + " ms");
    }

    public static HashMap<Object, List<Object>> copyGraph(HashMap<Object, List<Object>> graph) {
    
    
        HashMap<Object, List<Object>> copyGraph = new HashMap<>();
        for (Object key : graph.keySet()) {
    
    
            copyGraph.put(key, new ArrayList<>(graph.get(key)));
        }
        return copyGraph;
    }

}

结果输出

------------------ 匈牙利算法 ------------------
leftSet: [A, B, C, D, E, F, G]
rightSet: [1, 2, 3, 4, 5, 6, 7]
Graph: {
    
    A=[1, 4], 1=[], B=[3, 6, 7], 2=[], C=[2, 4, 5], 3=[], D=[2, 7], 4=[], E=[5, 6, 7], 5=[], F=[3, 6], 6=[], G=[6, 7], 7=[]}
Match-1: A->1
Match-2: D->2
Match-3: F->3
Match-4: C->4
Match-5: E->5
Match-6: G->6
Match-7: B->7
最大匹配数为: 7
用时: 2 ms
------------------ HK 算法 ------------------
leftSet: [A, B, C, D, E, F, G]
rightSet: [1, 2, 3, 4, 5, 6, 7]
Graph: {
    
    A=[1, 4], 1=[], B=[3, 6, 7], 2=[], C=[2, 4, 5], 3=[], D=[2, 7], 4=[], E=[5, 6, 7], 5=[], F=[3, 6], 6=[], G=[6, 7], 7=[]}
Match-1: A->1
Match-2: D->2
Match-3: B->3
Match-4: C->4
Match-5: E->5
Match-6: F->6
Match-7: G->7
最大匹配数为: 7
用时: 1 ms

可以看到,匈牙利算法和HK算法求得的结果最大匹配数都是7,但是最大匹配不相同,说明测试案例存在多个最大匹配,也说明了HK算法和匈牙利算法逻辑的不同,导致求得的解可能不同。

猜你喜欢

转载自blog.csdn.net/weixin_51545953/article/details/129224713
今日推荐