go语言实现高速网络框架

使用go语言来实现一个网络框架,把 tcp 自定义通信协议做成一个模板

我们知道一般的网络协议为了适用大部分人的使用,会封装一些的功能,但是这样会对性能产生一定的影响,所以当我们的项目或者模块在条件允许的情况下,我们可以自己定义协议来实现高速的网络框架。

创建项目

创建一个文件夹,fast-web。新建protocol.go, request.go ,response.go, client.go, server.go。

我们的协议可以定为这种形式:

//请求
Request:     //命令        //参数的个数   //参数的长度
 version    command    argsLength    {argLength    arg}
  1byte      1byte       4byte          4byte    unknown
  
  
  
//响应
Response:
 version    reply    bodyLength    {body}
 1byte     1byte      4byte      unknown

复制代码

protocol.go


const (
   ProtocolVersion        = byte(1) // 协议版本号
   headerLengthInProtocol = 6       // 协议中头部占用的字节数
   argsLengthInProtocol   = 4       // 协议中参数个数占用的字节数
   argLengthInProtocol    = 4       // 协议中参数长度占用的字节数
   bodyLengthInProtocol   = 4       // 协议体长度占用的字节数
)

var (
   // 协议版本不匹配错误,如果客户端和服务端的版本不一样就会返回这个错误
   ProtocolVersionMismatchErr = errors.New("protocol version between client and server doesn't match")
)

复制代码

request.go

// 从 reader 中读取请求,并解析出命令和参数。
func readRequestFrom(reader io.Reader) (command byte, args [][]byte, err error) {

   // 读取头部,指定具体的大小,使用 ReadFull 读取满指定字节数据,如果数据还没传输过来,这个方法会进行等待
   header := make([]byte, headerLengthInProtocol)
   _, err = io.ReadFull(reader, header)
   if err != nil {
      return 0, nil, err
   }

   // 头部的第一个字节是协议版本号,拿出来判断协议版本号是否一致
   version := header[0]
   if version != ProtocolVersion {
      return 0, nil, ProtocolVersionMismatchErr
   }

   // 头部的第二个字节是命令,后面的四个字节是参数个数
   command = header[1]
   header = header[2:]

   // 所有的整数到字节数组的转换使用大端形式,所以这里使用 BigEndian 来将头部后四个字节转换为一个 uint32 数字
   argsLength := binary.BigEndian.Uint32(header)
   args = make([][]byte, argsLength)
   if argsLength > 0 {
      // 读取参数长度,同理使用大端处理,并一次性读取满参数
      argLength := make([]byte, argLengthInProtocol)
      for i := uint32(0); i < argsLength; i++ {
         _, err = io.ReadFull(reader, argLength)
         if err != nil {
            return 0, nil, err
         }

         arg := make([]byte, binary.BigEndian.Uint32(argLength))
         _, err = io.ReadFull(reader, arg)
         if err != nil {
            return 0, nil, err
         }
         args[i] = arg
      }
   }
   return command, args, nil
}

// 将请求写入到 writer 中。
func writeRequestTo(writer io.Writer, command byte, args [][]byte) (int, error) {

   // 创建一个缓存区,并将协议版本号、命令和参数个数等写入缓存区
   request := make([]byte, headerLengthInProtocol)
   request[0] = ProtocolVersion
   request[1] = command
   binary.BigEndian.PutUint32(request[2:], uint32(len(args)))

   if len(args) > 0 {
      // 将参数都添加到缓存区
      argLength := make([]byte, argLengthInProtocol)
      for _, arg := range args {
         binary.BigEndian.PutUint32(argLength, uint32(len(arg)))
         request = append(request, argLength...)
         request = append(request, arg...)
      }
   }
   return writer.Write(request)
}
复制代码

response.go

const (
   SuccessReply = 0 // 成功的答复码
   ErrorReply   = 1 // 发生错误的答复码
)

// 从 reader 中读取数据并解析出响应内容。
func readResponseFrom(reader io.Reader) (reply byte, body []byte, err error) {

   // 读取指定字节数据
   header := make([]byte, headerLengthInProtocol)
   _, err = io.ReadFull(reader, header)
   if err != nil {
      return ErrorReply, nil, err
   }

   // 头部的第一个字节是协议版本号,如果版本号不一致很可能解析不成功,所以需要检查
   // 实际上这边可以做一个降级处理,就是尝试以响应的版本号去解析
   version := header[0]
   if version != ProtocolVersion {
      return ErrorReply, nil, errors.New("response " + ProtocolVersionMismatchErr.Error())
   }

   // 从头部解析出答复码还有响应体长度,同理,使用大端解析数字
   reply = header[1]
   header = header[2:]
   body = make([]byte, binary.BigEndian.Uint32(header))
   _, err = io.ReadFull(reader, body)
   if err != nil {
      return ErrorReply, nil, err
   }
   return reply, body, nil
}

// 将响应写入到 writer。
func writeResponseTo(writer io.Writer, reply byte, body []byte) (int, error) {

   // 将响应体相关数据写入响应缓存区,并发送
   bodyLengthBytes := make([]byte, bodyLengthInProtocol)
   binary.BigEndian.PutUint32(bodyLengthBytes, uint32(len(body)))

   response := make([]byte, 2, headerLengthInProtocol+len(body))
   response[0] = ProtocolVersion
   response[1] = reply
   response = append(response, bodyLengthBytes...)
   response = append(response, body...)
   return writer.Write(response)
}

// 向 writer 写入错误信息为 msg 的响应。
func writeErrorResponseTo(writer io.Writer, msg string) (int, error) {
   return writeResponseTo(writer, ErrorReply, []byte(msg))
}
复制代码

client.go

// 客户端结构。
type Client struct {

   // 和服务端建立的连接。
   conn   net.Conn

   // 通往服务端的读取器。
   reader io.Reader
}

// 创建新的客户端。
func NewClient(network string, address string) (*Client, error) {

   // 和服务端建立连接
   conn, err := net.Dial(network, address)
   if err != nil {
      return nil, err
   }
   return &Client{
      conn:   conn,
      reader: bufio.NewReader(conn),
   }, nil
}

// 执行命令。
func (c *Client) Do(command byte, args [][]byte) (body []byte, err error) {

   // 包装请求然后发送给服务端
   _, err = writeRequestTo(c.conn, command, args)
   if err != nil {
      return nil, err
   }

   // 读取服务端返回的响应
   reply, body, err := readResponseFrom(c.reader)
   if err != nil {
      return body, err
   }

   // 如果是错误答复码,将内容包装成 error 并返回
   if reply == ErrorReply {
      return body, errors.New(string(body))
   }
   return body, nil
}

// 关闭客户端。
func (c *Client) Close() error {
   return c.conn.Close()
}
复制代码

server.go

var (
   // 找不到对应的命令处理器错误
   commandHandlerNotFoundErr = errors.New("failed to find a handler of command")
)

// 服务端结构。
type Server struct {

   // 监听器,这个应该大家都很熟悉了吧。
   listener net.Listener

   // 命令处理器,通过命令可以找到对应的处理器。
   handlers map[byte]func(args [][]byte) (body []byte, err error)
}

// 创建新的服务端。
func NewServer() *Server {
   return &Server{
      handlers: map[byte]func(args [][]byte) (body []byte, err error){},
   }
}

// 注册命令处理器。
func (s *Server) RegisterHandler(command byte, handler func(args [][]byte) (body []byte, err error)) {
   s.handlers[command] = handler
}

// 监听并服务于 network 和 address。
func (s *Server) ListenAndServe(network string, address string) (err error) {

   // 监听指定地址
   s.listener, err = net.Listen(network, address)
   if err != nil {
      return err
   }

   // 使用 WaitGroup 记录连接数,并等待所有连接处理完毕
   wg := &sync.WaitGroup{}
   for {
      // 等待客户端连接
      conn, err := s.listener.Accept()
      if err != nil {
         // This error means listener has been closed
         // See src/internal/poll/fd.go@ErrNetClosing
         if strings.Contains(err.Error(), "use of closed network connection") {
            break
         }
         continue
      }

      // 记录连接
      wg.Add(1)
      go func() {
         defer wg.Done()
         s.handleConn(conn)
      }()
   }

   // 等待所有连接处理完毕
   wg.Wait()
   return nil
}

// 处理连接。
func (s *Server) handleConn(conn net.Conn) {

   // 将连接包装成缓冲读取器,提高读取的性能
   reader := bufio.NewReader(conn)
   defer conn.Close()

   for {
      // 读取并解析请求请求
      command, args, err := readRequestFrom(reader)
      if err != nil {
         if err == ProtocolVersionMismatchErr {
            continue
         }
         return
      }

      // 处理请求
      reply, body, err := s.handleRequest(command, args)
      if err != nil {
         writeErrorResponseTo(conn, err.Error())
         continue
      }

      // 发送处理结果的响应
      _, err = writeResponseTo(conn, reply, body)
      if err != nil {
         continue
      }
   }
}

// 处理请求。
func (s *Server) handleRequest(command byte, args [][]byte) (reply byte, body []byte, err error) {

   // 从命令处理器集合中选出对应的处理器
   handle, ok := s.handlers[command]
   if !ok {
      return ErrorReply, nil, commandHandlerNotFoundErr
   }

   // 将处理结果返回
   body, err = handle(args)
   if err != nil {
      return ErrorReply, body, err
   }
   return SuccessReply, body, err
}

// 关闭服务端的方法。
func (s *Server) Close() error {
   if s.listener == nil {
      return nil
   }
   return s.listener.Close()
}
复制代码

注:代码提取自vex框架。

Guess you like

Origin juejin.im/post/7040366330159038501