kitex 中 consistent hashing 的实现

一致性哈希算法(consistent hashing)

kitex 中一致性的很多细节都和我预先理解的不一样。

  • 这种负载均衡算法是在client侧实现的,那么client是怎么知道所有的ip的? 感觉这种算法应该是做一个中间件比较好,client请求实现一致性hash的中间件,中间件依据一致性hash算法来选取节点返回ip port,client侧应当不关注路由算法才对。
  • 算法中环是用数组实现的,这不奇怪。但节点的路由是二分查找找到的,这就有点奇怪。
  • consistent hashing 在consist.go中,很难理解这个文件为什么不叫consistent.go
  • buildWeightedVirtualNodes和buildVirtualNodes的代码长相惊人的类似,一看就是ctrl c v后懒得改

客户端测试代码:

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/cloudwego/kitex-examples/kitex_gen/api"
	"github.com/cloudwego/kitex-examples/kitex_gen/api/echo"
	"github.com/cloudwego/kitex/client"
	"github.com/cloudwego/kitex/pkg/loadbalance"
)

type ctxKey int

const (
	ctxConsistentKey ctxKey = iota
)

func main() {
    
    
	opt := loadbalance.NewConsistentHashOption(func(ctx context.Context, request interface{
    
    }) string {
    
    
		key, _ := ctx.Value(ctxConsistentKey).(string)
		return key
	})
	opt.Weighted = true
	lb := loadbalance.NewConsistBalancer(opt)
	client, err := echo.NewClient("echo", client.WithHostPorts("0.0.0.0:8801", "0.0.0.0:8802"), client.WithLoadBalancer(lb))
	if err != nil {
    
    
		log.Fatal(err)
	}

	var i int = 0
	for {
    
    
		// call a server
		ctx := context.Background()
		ctx = context.WithValue(ctx, ctxConsistentKey, "my key"+fmt.Sprintf("%d", i))
		req := &api.Request{
    
    Message: "my request" + fmt.Sprintf("%d", i)}
		resp, err := client.Echo(ctx, req)
		if err != nil {
    
    
			log.Fatal(err)
		}
		log.Println("call id :"+fmt.Sprintf("%d", i), resp)
		time.Sleep(time.Millisecond)
		// time.Sleep(time.Second)
		i++

	}

}
// 这一行进行配置,这里没有进行到后端的路由
	client, err := echo.NewClient("echo", client.WithHostPorts("0.0.0.0:8801", "0.0.0.0:8802"), client.WithLoadBalancer(lb))

// 调用时才进行路由
		resp, err := client.Echo(ctx, req)

调用堆栈:

在这里插入图片描述

TODO:客户端有缓存,如果所路由到的对端挂掉,这个缓存会清空重建吗? 这块没看
在这里插入描述

这里依据是否根据权重来判断,buildWeightedVirtualNodes 和buildVirtualNodes 函数没处理好,冗余代码比较多

func (cb *consistBalancer) buildNodes(ins []discovery.Instance) ([]realNode, []virtualNode) {
    
    
	ret := make([]realNode, len(ins))
	for i := range ins {
    
    
		ret[i].Ins = ins[i]
	}
	if cb.opt.Weighted {
    
    
		return ret, cb.buildWeightedVirtualNodes(ret)
	}
	return ret, cb.buildVirtualNodes(ret)
}

建立虚拟节点:(代码有改动)


// build virtual nodes
func (cb *consistBalancer) buildWeightedVirtualNodes(rNodes []realNode) []virtualNode {
    
    
	if len(rNodes) == 0 {
    
    
		return []virtualNode{
    
    }
	}
	vlen := 0
	for i := range rNodes {
    
    
		//                      10                   100
		vlen += rNodes[i].Ins.Weight() * int(cb.opt.VirtualFactor)
	}

	//                           2000
	ret := make([]virtualNode, vlen)
	if vlen == 0 {
    
    
		return ret
	}
	maxLen := 0
	for i := range rNodes {
    
    
		// TODO 优化 代码难看,使用 max
		if len(rNodes[i].Ins.Address().String()) > maxLen {
    
    
			maxLen = len(rNodes[i].Ins.Address().String())
		}
	}
	// l-> length
	l := maxLen + 1 + cb.opt.virtualFactorLen // "$address + # + itoa(i)"
	// pre-allocate []byte here, and reuse it to prevent memory allocation
	b := make([]byte, l)

	// record the start index
	cur := 0
	for i := range rNodes {
    
    
		ins := rNodes[i].Ins
		bAddr := utils.StringToSliceByte(ins.Address().String())
		// assign the first few bits of b to string
		copy(b, bAddr)

		// initialize the last few bits, skipping '#'
		for j := len(bAddr) + 1; j < len(b); j++ {
    
    
			b[j] = 0
		}
		b[len(bAddr)] = '#'

		// len of cur
		len := int(cb.opt.VirtualFactor) * ins.Weight()

		for j := 0; j < len; j++ {
    
    
			k := j
			cnt := 0
			// assign values to b one by one, starting with the last one
			for k > 0 {
    
    
				b[l-1-cnt] = byte(k % 10)
				k /= 10
				cnt++
			}
			// at this point, the index inside ret should be cur + j
			index := cur + j
			log.Println("b: ", b,"cur :", cur,"j :", j,"index :", index)
			ret[index].hash = xxhash.Sum64(b)
			ret[index].RealNode = &rNodes[i]
		}
		cur += len
	}
	sort.Sort(&vNodeType{
    
    s: ret})
	return ret
}

ret[index].hash = xxhash.Sum64(b)

这里计算出hash
在这里插入图片描述

一张图片说明一切:
拼接出:hash的key为: 【ip】【port】【#】【序列号】如: 【0.0.0.0:8802#123】

扫描二维码关注公众号,回复: 14719317 查看本文章

感觉这里的#号好像没什么用,可能是为了方便debug
在这里插入图片描述

最后得到的是一个数组,数组大小依据虚拟节点个数还有实例个数确定,数组依据hash 的大小来确定,排序是为了后面的二分能找到具体的节点。

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_38871173/article/details/122909139