给golang增加websocket模块

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yueguanghaidao/article/details/46334483

    最近打算做一款类似腾讯《脑力达人》的H5游戏,之前打算用skynet来做,所以给skynet增加了websocket模块,

https://github.com/Skycrab/skynet_websocket。刚好最近在学习golang,考虑之下打算用golang来实现,说不定过段时间

还能整个golang游戏服务器。之前我一直认为Python是我的真爱,但现在真心喜欢golang,也许这也是弥补我静态语言

的缺失吧,虽然C++/C还算熟悉,但没有工程经验,始终觉得缺少点什么。我相信golang以后会在服务器领域有一席之地,

现在研究也算投资吧,等golang越来越成熟,gc越来越高效,会有很多人转投golang的怀抱。

    我始终相信,一门语言一种文化。当我写Python时,我很少会考虑效率,想的更多的是简洁与优雅实现; 但当我写golang时,

时不时会左右比较,在int32与int64之间徘徊,估算本次大概需要多少byte进行内存预分配……在Python中即使你考虑了,

大多也是徒劳,语言本身很多没有提供。语言的文化,让我痴迷。

    算上前一篇写的定时器(http://blog.csdn.net/yueguanghaidao/article/details/46290539)和本篇的websocket,还差不少东西才能组成游戏服务器,慢慢填坑吧。

     有人说,golang的websocket库很多,何必造轮子?但自己写的后期好优化,更新方便,造轮子是快速学习的途径,如果时间

允许,多多造轮子,会在中途收获很多。

github地址:https://github.com/Skycrab/code/tree/master/Go/websocket

    首先看看如何使用:

package websocket

import (
	"fmt"
	"net/http"
	"testing"
)

// MyHandler is a demo WsHandler that accepts every origin and logs each
// event to stdout.
type MyHandler struct{}

// CheckOrigin accepts connections from any origin.
func (wd MyHandler) CheckOrigin(origin, host string) bool {
	return true
}

// OnOpen greets the newly connected client with a text message.
func (wd MyHandler) OnOpen(ws *Websocket) {
	fmt.Println("OnOpen")
	// Report a failed greeting instead of silently dropping the error.
	if err := ws.SendText([]byte("hello world from server")); err != nil {
		fmt.Println("SendText:", err)
	}
}

// OnMessage logs each received message and its length in bytes.
func (wd MyHandler) OnMessage(ws *Websocket, message []byte) {
	fmt.Println("OnMessage:", string(message), len(message))
}

// OnClose logs the close status code and reason sent by the client.
func (wd MyHandler) OnClose(ws *Websocket, code uint16, reason []byte) {
	fmt.Println("OnClose", code, string(reason))
}

// OnPong logs the payload of each pong frame.
func (wd MyHandler) OnPong(ws *Websocket, data []byte) {
	fmt.Println("OnPong:", string(data))
}

// TestWebsocket is a manual test: it serves WebSocket connections on :8001
// at /ws using MyHandler and blocks until the server fails, so it is skipped
// in -short mode.
func TestWebsocket(t *testing.T) {
	if testing.Short() {
		t.Skip("manual test: serves WebSocket connections on :8001 until interrupted")
	}

	// A private mux avoids the panic DefaultServeMux raises when the same
	// pattern is registered twice (e.g. go test -count=2).
	mux := http.NewServeMux()
	mux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) {
		fmt.Println("...")
		opt := Option{Handler: MyHandler{}, MaskOutgoing: false}
		ws, err := New(w, r, &opt)
		if err != nil {
			// This handler runs on a server goroutine, where t.Fatal must
			// not be called; record the failure and stop handling.
			t.Error(err)
			return
		}
		ws.Start()
	})

	fmt.Println("server start")
	if err := http.ListenAndServe(":8001", mux); err != nil {
		t.Fatal(err)
	}
}

使用方法和之前的skynet版本类似,都仿照了tornado websocket的回调方式。

MyHandler实现了WsHandler接口,如果你并不关注所有事件,可以在自己的类型中嵌入(embed)WsDefaultHandler(golang没有继承),WsDefaultHandler为所有的事件

提供了默认实现。

通过Option实现了默认参数功能(传入nil即使用默认值),第二个参数代表是否mask发送的数据:RFC 6455要求客户端发送的帧必须mask,服务端发送的帧不能mask,所以默认为false。

由于暂时没有websocket client的需求,所以没有提供,需要时再添加吧。

    对比一下golang和lua的实现,代码行数并没有增加多少,golang是400行,lua是340行,不得不说golang编码效率的确

赶得上动态语言。在编写golang和lua实现时,我明显感觉到静态语言具有很大优势,lua出错提示不给力,这也是动态语言的

痛处吧。好消息是Python 3.5引入了类型提示(type hints),配合静态检查工具就能做类型检查,我觉得的确是一大利器。

在这里把代码贴一下,方便查看。

package websocket

import (
	"bufio"
	"bytes"
	"crypto/rand"
	"crypto/sha1"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt" // NOTE(review): only referenced by a commented-out debug line in this listing; remove if still unused
	"io"
	"net"
	"net/http"
	"strings"
)

// Handshake errors: acceptConnection (and therefore New) reports one of these
// when the client's opening request cannot be accepted.

// ErrUpgrade means the Upgrade header is not "websocket" (case-insensitive).
var ErrUpgrade = errors.New("Can \"Upgrade\" only to \"WebSocket\"")

// ErrConnection means the Connection header does not mention "upgrade".
var ErrConnection = errors.New("\"Connection\" must be \"Upgrade\"")

// ErrCrossOrigin means the handler's CheckOrigin rejected the request origin.
var ErrCrossOrigin = errors.New("Cross origin websockets not allowed")

// ErrSecVersion means the client did not ask for protocol version 13.
var ErrSecVersion = errors.New("HTTP/1.1 Upgrade Required\r\nSec-WebSocket-Version: 13\r\n\r\n")

// ErrSecKey means the Sec-WebSocket-Key header is missing or empty.
var ErrSecKey = errors.New("\"Sec-WebSocket-Key\" must not be  nil")

// ErrHijacker means the http.ResponseWriter does not implement http.Hijacker.
var ErrHijacker = errors.New("Not implement http.Hijacker")

// Framing errors returned by RecvFrame when a received frame breaks the
// protocol.

// ErrReservedBits means one of the RSV1-3 bits is set, i.e. the client uses
// an extension that was never negotiated.
var ErrReservedBits = errors.New("Reserved_bits show using undefined extensions")

// ErrFrameOverload means a control frame announced a payload of 126+ bytes.
var ErrFrameOverload = errors.New("Control frame payload overload")

// ErrFrameFragmented means a control frame arrived without the FIN bit.
var ErrFrameFragmented = errors.New("Control frame must not be fragmented")

// ErrInvalidOpcode means the frame opcode is not one this package handles.
var ErrInvalidOpcode = errors.New("Invalid frame opcode")

// crlf terminates each line of the raw handshake response.
var crlf = []byte("\r\n")

// challengeKey is appended to the client's Sec-WebSocket-Key before hashing
// to build the Sec-WebSocket-Accept value (the fixed GUID of RFC 6455).
var challengeKey = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// Ported from the author's Lua implementation: https://github.com/Skycrab/skynet_websocket/blob/master/websocket.lua

// WsHandler receives the events of a WebSocket connection. CheckOrigin is
// consulted during the handshake; the other methods are callbacks fired as
// the connection opens, receives frames and closes.
type WsHandler interface {
	// CheckOrigin reports whether a handshake from origin may be accepted.
	// It is only consulted when the client sent an origin header.
	CheckOrigin(origin, host string) bool

	// OnOpen is called once the handshake response has been written.
	OnOpen(ws *Websocket)

	// OnMessage receives each complete, non-empty data message.
	OnMessage(ws *Websocket, message []byte)

	// OnPong receives the payload of each pong frame.
	OnPong(ws *Websocket, data []byte)

	// OnClose is called when the client sends a close frame, with its status
	// code (0 if absent) and reason (nil if absent).
	OnClose(ws *Websocket, code uint16, reason []byte)
}

// WsDefaultHandler implements every WsHandler method as a no-op and accepts
// all origins. Embed it in a handler type to override only the events that
// matter to you.
type WsDefaultHandler struct {
	// checkOriginOr was meant to control origin checking (default true), but
	// CheckOrigin below never reads it.
	checkOriginOr bool
}

// CheckOrigin accepts every origin.
func (WsDefaultHandler) CheckOrigin(origin, host string) bool { return true }

// OnOpen does nothing.
func (WsDefaultHandler) OnOpen(ws *Websocket) {}

// OnMessage does nothing.
func (WsDefaultHandler) OnMessage(ws *Websocket, message []byte) {}

// OnClose does nothing.
func (WsDefaultHandler) OnClose(ws *Websocket, code uint16, reason []byte) {}

// OnPong does nothing.
func (WsDefaultHandler) OnPong(ws *Websocket, data []byte) {}

// Websocket is a server-side WebSocket connection produced by New.
type Websocket struct {
	conn    net.Conn          // raw connection returned by Hijack
	rw      *bufio.ReadWriter // buffered reader/writer returned by Hijack; frames go through it
	handler WsHandler         // receives connection events

	// Closing-handshake progress: clientTerminated is set when the client's
	// close frame arrives, serverTerminated once our close frame was sent.
	clientTerminated, serverTerminated bool

	maskOutgoing bool // set the MASK bit on frames we send
}

// Option configures New. Passing a nil *Option selects the defaults.
type Option struct {
	// Handler receives connection events. New uses WsDefaultHandler when
	// the Option itself is nil.
	Handler WsHandler

	// MaskOutgoing controls the MASK bit on frames sent by this end
	// (default false; servers must not mask their frames).
	MaskOutgoing bool
}

// challengeResponse builds the raw HTTP "101 Switching Protocols" reply for
// the opening handshake. Sec-WebSocket-Accept is base64(SHA-1(key +
// challengeKey)); a Sec-WebSocket-Protocol line is added only when protocol
// is non-empty. The response ends with an empty line.
func challengeResponse(key, protocol string) []byte {
	digest := sha1.Sum(append([]byte(key), challengeKey...))

	var resp bytes.Buffer
	resp.WriteString("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ")
	resp.WriteString(base64.StdEncoding.EncodeToString(digest[:]))
	resp.Write(crlf)

	if protocol != "" {
		resp.WriteString("Sec-WebSocket-Protocol: ")
		resp.WriteString(protocol)
		resp.Write(crlf)
	}

	resp.Write(crlf) // blank line ends the header section
	return resp.Bytes()
}

// acceptConnection checks the client's opening handshake and, when it is
// acceptable, returns the raw "101 Switching Protocols" response to send.
//
// It fails with ErrUpgrade, ErrConnection, ErrSecVersion, ErrCrossOrigin or
// ErrSecKey, checked in that order. h.CheckOrigin is consulted only when the
// client sent an Origin (or legacy Sec-WebSocket-Origin) header and is given
// r.Host: net/http moves the Host header out of r.Header into r.Host, so
// r.Header.Get("Host") would always be empty. When the client offers
// sub-protocols, the first one is echoed back.
func acceptConnection(r *http.Request, h WsHandler) (challenge []byte, err error) {
	// Upgrade header should be present and equal to "websocket" (any case).
	if !strings.EqualFold(r.Header.Get("Upgrade"), "websocket") {
		return nil, ErrUpgrade
	}

	// Connection is a token list (e.g. "keep-alive, Upgrade") and some
	// proxies/load balancers rewrite it, so look for "upgrade" anywhere.
	if !strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") {
		return nil, ErrConnection
	}

	// Only protocol version 13 (RFC 6455) is supported.
	if r.Header.Get("Sec-Websocket-Version") != "13" {
		return nil, ErrSecVersion
	}

	// Version 8 clients send "Sec-Websocket-Origin"; version 13 sends "Origin".
	origin := r.Header.Get("Origin")
	if origin == "" {
		origin = r.Header.Get("Sec-Websocket-Origin")
	}
	if origin != "" && !h.CheckOrigin(origin, r.Host) {
		return nil, ErrCrossOrigin
	}

	key := r.Header.Get("Sec-Websocket-Key")
	if key == "" {
		return nil, ErrSecKey
	}

	// Sec-WebSocket-Protocol is a comma-separated list; echo the first entry
	// without the whitespace that may surround it.
	protocol := r.Header.Get("Sec-Websocket-Protocol")
	if idx := strings.IndexByte(protocol, ','); idx != -1 {
		protocol = protocol[:idx]
	}
	protocol = strings.TrimSpace(protocol)

	return challengeResponse(key, protocol), nil
}

// websocketMask XORs data in place with the repeating 4-byte masking key.
// Applying it twice with the same key restores the original bytes.
func websocketMask(mask []byte, data []byte) {
	for pos, b := range data {
		data[pos] = b ^ mask[pos&3]
	}
}

// New performs the server side of the WebSocket opening handshake for r and
// returns a connection ready for Start.
//
// opt may be nil, in which case WsDefaultHandler is used and outgoing frames
// are not masked; a nil opt.Handler also falls back to WsDefaultHandler.
// If the handshake is rejected, an HTTP error response is written to w
// (403 for a refused origin; 426 plus a Sec-WebSocket-Version header for an
// unsupported version, as RFC 6455 section 4.2.2 requires; 400 otherwise) and
// the error is returned. If w cannot be hijacked, a 500 is written and
// ErrHijacker returned. On success the connection is hijacked, the 101
// response is written and the handler's OnOpen is called before returning.
func New(w http.ResponseWriter, r *http.Request, opt *Option) (*Websocket, error) {
	var h WsHandler = WsDefaultHandler{true}
	var maskOutgoing bool
	if opt != nil {
		if opt.Handler != nil {
			h = opt.Handler
		}
		maskOutgoing = opt.MaskOutgoing
	}

	challenge, err := acceptConnection(r, h)
	if err != nil {
		code := 400 // Bad Request
		switch err {
		case ErrCrossOrigin:
			code = 403 // Forbidden
		case ErrSecVersion:
			// Tell the client which protocol version we do speak.
			w.Header().Set("Sec-WebSocket-Version", "13")
			code = 426 // Upgrade Required
		}
		w.WriteHeader(code)
		w.Write([]byte(err.Error())) // best effort: the handshake already failed
		return nil, err
	}

	hj, ok := w.(http.Hijacker)
	if !ok {
		http.Error(w, ErrHijacker.Error(), 500)
		return nil, ErrHijacker
	}
	conn, rw, err := hj.Hijack()
	if err != nil {
		return nil, err
	}

	ws := &Websocket{
		conn:         conn,
		rw:           rw,
		handler:      h,
		maskOutgoing: maskOutgoing,
	}
	// rw's writer is still empty here, so writing straight to conn is safe.
	if _, err := conn.Write(challenge); err != nil {
		conn.Close()
		return nil, err
	}
	ws.handler.OnOpen(ws)
	return ws, nil
}

// read fills buf completely from the buffered connection reader. A short
// read yields io.ErrUnexpectedEOF (or io.EOF if nothing was read at all).
func (ws *Websocket) read(buf []byte) error {
	if _, err := io.ReadFull(ws.rw, buf); err != nil {
		return err
	}
	return nil
}

// SendFrame writes one frame with the given FIN flag and opcode and flushes
// it to the connection.
//
// The payload length uses the minimal encoding RFC 6455 section 5.2 requires:
// 7 bits below 126, 16 bits up to 0xFFFF, 64 bits otherwise. When the
// connection masks outgoing frames, a fresh 4-byte key from crypto/rand is
// written after the length and the payload is masked inside the outgoing
// buffer; data itself is never modified. Errors from generating the key,
// writing or flushing are returned.
func (ws *Websocket) SendFrame(fin bool, opcode byte, data []byte) error {
	// Largest header: 2 bytes + 8-byte extended length + 4-byte masking key.
	buf := make([]byte, 0, len(data)+14)

	first := opcode
	if fin {
		first |= 0x80
	}
	buf = append(buf, first)

	var maskBit byte
	if ws.maskOutgoing {
		maskBit = 0x80
	}
	length := len(data)
	switch {
	case length < 126:
		buf = append(buf, byte(length)|maskBit)
	case length <= 0xFFFF:
		buf = append(buf, 126|maskBit, 0, 0)
		binary.BigEndian.PutUint16(buf[len(buf)-2:], uint16(length))
	default:
		buf = append(buf, 127|maskBit, 0, 0, 0, 0, 0, 0, 0, 0)
		binary.BigEndian.PutUint64(buf[len(buf)-8:], uint64(length))
	}

	if ws.maskOutgoing {
		var key [4]byte
		if _, err := rand.Read(key[:]); err != nil {
			return err
		}
		buf = append(buf, key[:]...)
		start := len(buf)
		buf = append(buf, data...)
		websocketMask(key[:], buf[start:]) // mask the copy, not the caller's slice
	} else {
		buf = append(buf, data...)
	}

	if _, err := ws.rw.Write(buf); err != nil {
		return err
	}
	return ws.rw.Flush()
}

// SendText sends data as a single, unfragmented text frame (opcode 0x1).
func (ws *Websocket) SendText(data []byte) error { return ws.sendWhole(0x1, data) }

// SendBinary sends data as a single, unfragmented binary frame (opcode 0x2).
func (ws *Websocket) SendBinary(data []byte) error { return ws.sendWhole(0x2, data) }

// SendPing sends a ping control frame (opcode 0x9) carrying data.
func (ws *Websocket) SendPing(data []byte) error { return ws.sendWhole(0x9, data) }

// SendPong sends a pong control frame (opcode 0xA) carrying data.
func (ws *Websocket) SendPong(data []byte) error { return ws.sendWhole(0xA, data) }

// sendWhole sends data as one frame with the FIN bit set.
func (ws *Websocket) sendWhole(opcode byte, data []byte) error {
	return ws.SendFrame(true, opcode, data)
}

// Close performs this side of the closing handshake.
//
// The first call sends a close frame (opcode 0x8) built by closePayload and
// marks the server as terminated; later calls send nothing. The underlying
// connection is closed only once the client's close frame has been seen
// (clientTerminated); otherwise it stays open until that frame arrives.
// NOTE(review): the error from sending the close frame is discarded.
func (ws *Websocket) Close(code uint16, reason []byte) {
	if !ws.serverTerminated {
		ws.SendFrame(true, 0x8, closePayload(code, reason))
		ws.serverTerminated = true
	}
	if ws.clientTerminated {
		ws.conn.Close()
	}
}

// closePayload encodes a close-frame body: the big-endian status code
// followed by reason. Code 0 with a non-nil reason is sent as 1000; code 0
// with a nil reason produces an empty body.
func closePayload(code uint16, reason []byte) []byte {
	if code == 0 && reason != nil {
		code = 1000
	}
	payload := make([]byte, 0, len(reason)+2)
	if code != 0 {
		payload = append(payload, byte(code>>8), byte(code))
	}
	return append(payload, reason...)
}

// ErrFrameLength is returned by RecvFrame when a frame announces a payload
// length that RFC 6455 forbids (most significant bit of the 64-bit length
// set) or that does not fit in an int on this platform.
var ErrFrameLength = errors.New("Invalid frame payload length")

// RecvFrame reads and processes a single frame from the client.
//
// final reports the frame's FIN bit. For data frames (continuation, text,
// binary) message is the unmasked payload; it is non-nil even when empty.
// Control frames are handled in place and yield a nil message:
//   - close (0x8): records that the client terminated, replies via Close and
//     then calls the handler's OnClose with the status code and reason;
//   - ping (0x9): answers with a pong carrying the same payload;
//   - pong (0xA): hands the payload to the handler's OnPong.
//
// Protocol violations are reported as ErrReservedBits, ErrInvalidOpcode,
// ErrFrameOverload, ErrFrameFragmented or ErrFrameLength; I/O failures are
// returned unchanged.
func (ws *Websocket) RecvFrame() (final bool, message []byte, err error) {
	final, _, message, err = ws.recvFrame()
	return
}

// recvFrame implements RecvFrame and additionally returns the frame opcode,
// which Recv needs to tell control frames apart from message fragments.
func (ws *Websocket) recvFrame() (final bool, opcode byte, message []byte, err error) {
	var scratch [8]byte
	if err = ws.read(scratch[:2]); err != nil {
		return
	}
	b0, b1 := scratch[0], scratch[1]
	final = b0&0x80 != 0
	opcode = b0 & 0xf
	isControl := opcode&0x8 != 0

	if b0&0x70 != 0 {
		// Client is using as-yet-undefined extensions.
		err = ErrReservedBits
		return
	}

	// 0x0 (continuation) is valid: it carries every fragment after the first,
	// including the final one.
	switch opcode {
	case 0x0, 0x1, 0x2, 0x8, 0x9, 0xA:
	default:
		err = ErrInvalidOpcode
		return
	}

	masked := b1&0x80 != 0
	length := uint64(b1 & 0x7f)

	if isControl && length >= 126 {
		err = ErrFrameOverload
		return
	}
	if isControl && !final {
		err = ErrFrameFragmented
		return
	}

	// Decode the extended payload length, if any.
	switch length {
	case 126:
		if err = ws.read(scratch[:2]); err != nil {
			return
		}
		length = uint64(binary.BigEndian.Uint16(scratch[:2]))
	case 127:
		if err = ws.read(scratch[:8]); err != nil {
			return
		}
		length = binary.BigEndian.Uint64(scratch[:8])
	}
	// RFC 6455 section 5.2: the most significant bit of a 64-bit length MUST
	// be 0. Bounding by the platform's max int also keeps make() below from
	// panicking on 32-bit builds.
	// NOTE(review): there is still no application-level size cap, so a client
	// can request a very large allocation; consider enforcing one.
	if length > uint64(^uint(0)>>1) {
		err = ErrFrameLength
		return
	}

	var mask [4]byte
	if masked {
		if err = ws.read(mask[:]); err != nil {
			return
		}
	}

	message = make([]byte, length)
	if length > 0 {
		if err = ws.read(message); err != nil {
			return
		}
		if masked {
			websocketMask(mask[:], message)
		}
	}

	switch opcode {
	case 0x8: // close
		var code uint16
		var reason []byte
		if length >= 2 {
			code = binary.BigEndian.Uint16(message[:2])
		}
		if length > 2 {
			reason = message[2:]
		}
		message = nil
		ws.clientTerminated = true
		ws.Close(0, nil)
		ws.handler.OnClose(ws, code, reason)
	case 0x9: // ping: the pong must echo the ping's application data
		err = ws.SendPong(message)
		message = nil
	case 0xA: // pong
		ws.handler.OnPong(ws, message)
		message = nil
	}
	return
}

// Recv reads frames until one complete data message has arrived and returns
// its payload, concatenating fragments. Ping and pong frames arriving in
// between (RFC 6455 section 5.4 allows them to be interleaved with
// fragments) are handled by the frame reader and do not end the message. A
// non-empty message is also passed to the handler's OnMessage.
//
// If the client sends a close frame, Recv returns (nil, nil) after the
// closing handshake has been handled. On error it returns the payload
// gathered so far together with the error; a continuation frame with no
// message in progress, or a new text/binary frame in the middle of one, is
// reported as ErrInvalidOpcode.
func (ws *Websocket) Recv() ([]byte, error) {
	data := make([]byte, 0, 8)
	inMessage := false // true once a text/binary frame has started a message
	for {
		final, opcode, message, err := ws.recvFrame()
		if err != nil {
			return data, err
		}
		switch opcode {
		case 0x8: // close: no more data will follow
			return nil, nil
		case 0x9, 0xA: // ping/pong: already handled, keep assembling
			continue
		case 0x0: // continuation
			if !inMessage {
				return data, ErrInvalidOpcode
			}
		default: // 0x1 text, 0x2 binary
			if inMessage {
				return data, ErrInvalidOpcode
			}
			inMessage = true
		}
		data = append(data, message...)
		if final {
			break
		}
	}
	if len(data) > 0 {
		ws.handler.OnMessage(ws, data)
	}
	return data, nil
}

// Start runs the receive loop, delivering events through the handler.
//
// It returns once the client's close frame has been processed (the closing
// handshake closes the connection) or as soon as Recv reports an error, in
// which case the connection is closed here. Previously the loop kept
// spinning after an error, re-reading and re-closing a dead connection.
func (ws *Websocket) Start() {
	for !ws.clientTerminated {
		if _, err := ws.Recv(); err != nil {
			ws.conn.Close()
			return
		}
	}
}


















猜你喜欢

转载自blog.csdn.net/yueguanghaidao/article/details/46334483