RocketMQ Consumer 负载均衡算法源码学习 -- AllocateMessageQueueConsistentHash

RocketMQ 提供了一致性hash 算法来做Consumer 和 MessageQueue的负载均衡。 源码中一致性hash 环的实现是很优秀的,我们一步一步分析。

  1. 一个Hash环包含多个节点, 我们用 MyNode 去封装节点, 方法 getKey() 封装获取节点的key。我们可以实现MyNode 去描述一个物理节点或虚拟节点。MyVirtualNode 实现 MyNode, 表示一个虚拟节点。这里注意:一个虚拟节点是依赖于一个物理节点,所以MyVirtualNode 中封装了 一个 泛型 T physicalNode。物理节点MyClientNode也是实现了这个MyNode接口,很好的设计。代码加注释如下:

     /**
      * 表示hash环的一个节点
      */
     public interface MyNode {
     
         /**
          * @return 节点的key
          */
         String getKey();
     }
    
     		/**
      * 虚拟节点
      */
     public class MyVirtualNode<T extends MyNode> implements MyNode {
     
     final T physicalNode;  // 主节点
     final int replicaIndex;  // 虚节点下标
     
         public MyVirtualNode(T physicalNode, int replicaIndex) {
             this.physicalNode = physicalNode;
             this.replicaIndex = replicaIndex;
         }
     
         @Override
         public String getKey() {
             return physicalNode.getKey() + "-" + replicaIndex;
         }
     
         /**
          * thisMyVirtualNode 是否是pNode 的 虚节点
          */
         public boolean isVirtualNodeOf(T pNode) {
             return physicalNode.getKey().equals(pNode.getKey());
         }
     
         public T getPhysicalNode() {
             return physicalNode;
         }
     }
         private static class MyClientNode implements MyNode {
             private final String clientID;
             public MyClientNode(String clientID) {
                 this.clientID = clientID;
             }
             @Override
             public String getKey() {
                 return clientID;
             }
         }
    
  2. 上面实现了节点, 一致性hash 下一个问题是怎么封装hash算法呢?RocketMQ 使用 MyHashFunction 接口定义hash算法。使用MD5 + bit 位hash的方式实现hash算法。我们完全可以自己实现hash算法,具体见我的“常见的一些hash函数”文章。MyMD5Hash 算法代码的如下:

     // MD5 hash 算法, 这里hash算法可以用常用的 hash 算法替换。
         private static class MyMD5Hash implements MyHashFunction {
             MessageDigest instance;
             public MyMD5Hash() {
                 try {
                     instance = MessageDigest.getInstance("MD5");
                 } catch (NoSuchAlgorithmException e) {
                 }
             }
     
             @Override
             public long hash(String key) {
                 instance.reset();
                 instance.update(key.getBytes());
                 byte[] digest = instance.digest();
     
                 long h = 0;
                 for (int i = 0; i < 4; i++) {
                     h <<= 8;
                     h |= ((int)digest[i]) & 0xFF;
                 }
                 return h;
             }
         }
    
  3. 现在,hash环的节点有了, hash算法也有了,最重要的是描述一个一致性hash 环。 想一想,这个环可以由N 个物理节点, 每个物理节点对应m个虚拟节点,节点位置用hash算法值描述。每个物理节点就是每个Consumer, 每个Consumer 的 id 就是 物理节点的key。 每个MessageQueue 的toString() 值 hash 后,用来找环上对应的最近的下一个物理节点。源码如下,这里展示主要的代码,其中最巧妙地是routeNode 方法, addNode 方法 注意我的注释:

    public class MyConsistentHashRouter<T extends MyNode> {
    
    private final SortedMap<Long, MyVirtualNode<T>> ring = new TreeMap<>(); // key是虚节点key的哈希值, value 是虚节点
    private final MyHashFunction myHashFunction;
    /**
     * @param pNodes 物理节点集合
     * @param vNodeCount 每个物理节点对应的虚节点数量
     * @param hashFunction hash 函数 用于 hash 各个节点
     */
    public MyConsistentHashRouter(Collection<T> pNodes, int vNodeCount, MyHashFunction hashFunction) {
        if (hashFunction == null) {
            throw new NullPointerException("Hash Function is null");
        }
        this.myHashFunction = hashFunction;
        if (pNodes != null) {
            for (T pNode : pNodes) {
                this.addNode(pNode, vNodeCount);
            }
        }
    }
    /**
     * 添加物理节点和它的虚节点到hash环。
     * @param pNode 物理节点
     * @param vNodeCount 虚节点数量。
     */
    public void addNode(T pNode, int vNodeCount) {
        if (vNodeCount < 0) {
            throw new IllegalArgumentException("ill virtual node counts :" + vNodeCount);
        }
        int existingReplicas = this.getExistingReplicas(pNode);
        for (int i = 0; i < vNodeCount; i++) {
            MyVirtualNode<T> vNode = new MyVirtualNode<T>(pNode, i + existingReplicas); // 创建一个新的虚节点,位置是 i+existingReplicas
            ring.put(this.myHashFunction.hash(vNode.getKey()), vNode); // 将新的虚节点放到hash环中
        }
    }
    /**
     * 根据一个给定的key 在 hash环中 找到离这个key最近的下一个物理节点
     * @param key 一个key, 用于找这个key 在环上最近的节点
     */
    public T routeNode(String key) {
        if (ring.isEmpty()) {
            return null;
        }
        Long hashVal = this.myHashFunction.hash(key);
        SortedMap<Long, MyVirtualNode<T>> tailMap = ring.tailMap(hashVal);
        Long nodeHashVal = !tailMap.isEmpty() ? tailMap.firstKey() : ring.firstKey();
        return ring.get(nodeHashVal).getPhysicalNode();
    }
    
    /**
     * @param pNode 物理节点
     * @return 当前这个物理节点对应的虚节点的个数
     */
    public int getExistingReplicas(T pNode) {
        int replicas = 0;
        for (MyVirtualNode<T> vNode : ring.values()) {
            if (vNode.isVirtualNodeOf(pNode)) {
                replicas++;
            }
        }
        return replicas;
    }
    
  4. 现在一致性hash 环有了, 剩下的就是 和rocketmq 的 consumer, mq 构成负载均衡策略了。比较简单, 代码如下:

     			/**
     	 * 基于一致性性hash环的Consumer负载均衡.
     	*/	 
    
     public class MyAllocateMessageQueueConsistentHash implements AllocateMessageQueueStrategy {
     
         // 每个物理节点对应的虚节点的个数
         private final int virtualNodeCnt;
         private final MyHashFunction customHashFunction;
     
         public MyAllocateMessageQueueConsistentHash() {
             this(10);   // 默认10个虚拟节点
         }
     
         public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt) {
             this(virtualNodeCnt, null);
     
         }
         public MyAllocateMessageQueueConsistentHash(int virtualNodeCnt, MyHashFunction customHashFunction) {
             if (virtualNodeCnt < 0) {
                 throw new IllegalArgumentException("illegal virtualNodeCnt : " + virtualNodeCnt);
             }
             this.virtualNodeCnt = virtualNodeCnt;
             this.customHashFunction = customHashFunction;
         }
     
         @Override
         public List<MessageQueue> allocate(String consumerGroup, String currentCID, List<MessageQueue> mqAll, List<String> cidAll) {
             // 省去一系列非空校验
             Collection<MyClientNode> cidNodes = new ArrayList<>();
             for (String cid : cidAll) {
                 cidNodes.add(new MyClientNode(cid));
             }
             final MyConsistentHashRouter<MyClientNode> router;
             if (this.customHashFunction != null) {
                 router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt, customHashFunction);
             }else {
                 router = new MyConsistentHashRouter<MyClientNode>(cidNodes, virtualNodeCnt);
             }
             List<MessageQueue> results = new ArrayList<MessageQueue>();  // 当前 currentCID 对应的 mq
             // 将每个mq 根据一致性hash 算法找到对应的物理节点(Consumer)
             for (MessageQueue mq : mqAll) {
                 MyClientNode clientNode = router.routeNode(mq.toString());   // 根据 mq toString() 方法做hash 和环上节点比较
                 if (clientNode != null && currentCID.equals(clientNode.getKey())) {
                     results.add(mq);
                 }
             }
             return results;
         }
     
         @Override
         public String getName() {
             return "CONSISTENT_HASH";
         }
     
         private static class MyClientNode implements MyNode {
             private final String clientID;
             public MyClientNode(String clientID) {
                 this.clientID = clientID;
             }
             @Override
             public String getKey() {
                 return clientID;
             }
         }
     
     }
    

猜你喜欢

转载自blog.csdn.net/ZHANGYONGHAO604/article/details/82426373