Java NIO实现WebSocket服务器

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Kurozaki_Kun/article/details/78843783

简介

在HTTP请求中,服务器往往处于被动的一方,通常都是客户端向服务器发送请求时,服务器才会做出响应,服务器并不会主动向客户端推送消息。因此WebSocket API就为此诞生。WebSocket API是HTML5中的一大特色,能够使得建立连接的双方在任意时刻相互推送消息,这意味着不同于HTTP,服务器服务器也可以主动向客户端推送消息了。

关于WebSocket的介绍,可以参考下一篇博文http://blog.csdn.net/zwto1/article/details/52493119#websocket%E5%8E%9F%E7%90%86

WebSocket协议的格式

为了实现一个能与H5的WebSocket API通信的服务器,我们需要先熟悉WebSocket数据包的格式。定的格式。

握手数据包

在一个连接建立以后,建立连接的双方才可以互相推送消息。双方通过握手即可建立一个连接。握手数据包的格式如下:

客户端向服务器发起请求

这里写图片描述

可以见到,客户端请求连接建立的数据包是一个字符串,而且第一行表明这实际上是一个HTTP报文。其中Connection: Upgrade以及Upgrade: websocket两字段就是用来告知服务器这是一个WebSocket握手请求。

服务器还要关心的一个字段是Sec-WebSocket-Key(倒数第二行),其值是一个随机base64字符串,服务器怎么处理该字符串请往下看。

服务器回应请求

这里写图片描述

可以看到HTTP状态码为101,同样,服务端也带有Connection和Upgrade字段来表明这是一个WebSocket数据包。

Sec-WebSocket-Accept字段是对请求报文中Sec-WebSocket-Key字段进行摘要运算的结果。其运算过程如下
1、将Sec-WebSocket-Key字段的值与字符串258EAFA5-E914-47DA-95CA-C5AB0DC85B11拼接。
2、对拼接后的字符串进行sha1运算,得到160位摘要(二进制)。
3、以base64的形式表示得到的摘要。
客户端会进行同样的运算,并且与服务器返回来的字段作对比,如果发现二者不相同,连接就无法建立了。

通信数据帧

通信数据帧的格式如下(参考官方文档https://www.rfc-editor.org/rfc/rfc6455.txt
这里写图片描述
其中各个字段的含义如下
FIN: 1bit,表示这是否为分片的最后一个数据帧。这是考虑到发送的数据有可能被分片的情况,如果存在分片,将此字段置1就表明这是最后一个分片。如果不存在分片,此字段恒为1。因为只有一个分片就一定是最后一个分片。

RSV1, RSV2, RSV3: 各1bit,全0。现在暂时用不上,为了将来可能用于功能拓展保留的字段。

Opcode: 4bits
指出数据的类型,值的解释如下

含义
0x0 附加数据帧
0x1 文本数据帧
0x2 二进制数据帧
0x3-0x7 暂无定义
0x8 关闭连接
0x9 表示ping
0xA 表示pong
0aB-0xF 暂无定义

MASK: 1bit
表明是否对数据进行掩码运算,置1表示使用掩码。从客户端向服务器发送的数据必须使用掩码。

Payload length: 7 bits, 7+16 bits, or 7+64 bits
表明数据的长度。
如果长度在0-125内,这7bits就表示数据的长度;
如果值为126,紧接着后面2字节(16bits)才表示数据的长度;
如果值为127,后面8字节(64bits)表示数据的长度。

Masking-key: 无 或 4 字节
如果掩码字段(MASK)置0,就不需要Masking-key。如果掩码字段为1,这4字节就是Masking-key,用它与数据部分进行异或运算。

Payload Data: 数据部分,长度可变。

关于其他详细说明可以参考官方文档,例如消息分片规则等。

实现一个WebSocket服务器(群聊天室例子)

为了更加深刻的理解这样一个协议,这里没有使用Java已经封装好操作的类库。

基于NIO监听端口

基于NIO中的ServerSocketChannel,实现一个接收并读取Socket内容的服务端套路如下。


public class WebSocketServer {

    private Selector serverSelector;
    private WebSocketListener socketListener;
    private boolean isRunning = true;

    public WebSocketServer(int serverPort, WebSocketListener socketListener) throws IOException {
        //初始化ServerSocketChannel
        ServerSocketChannel serverSocketChannel =
                ServerSocketChannel.open();
        serverSocketChannel.bind(new InetSocketAddress(serverPort));
        serverSocketChannel.configureBlocking(false);

        //创建选择器
        serverSelector = Selector.open();

        //注册ServerSocketChannel的ACCEPT事件至选择器
        serverSocketChannel.register(serverSelector, SelectionKey.OP_ACCEPT);
        this.socketListener = socketListener;
    }

    public void run() throws IOException {
        while (isRunning) {
            int selectCount = serverSelector.select();
            if (selectCount == 0)
                continue;

            Iterator<SelectionKey> iterator = serverSelector.selectedKeys().iterator();
            while (iterator.hasNext()) {
                SelectionKey selectKey = iterator.next();

                if (selectKey.isAcceptable()) {

                    //ACCEPT就绪,此时调用ServerSocketChannel的accept()方法可获得连接的SocketChannel对象,将其READ事件注册到选择器,就可以读取内容了。
                    ServerSocketChannel serverChannel = (ServerSocketChannel) selectKey.channel();
                    SocketChannel acceptSocketChannel = serverChannel.accept();
                  acceptSocketChannel.configureBlocking(false);  //记得设置为非阻塞模式 
                    acceptSocketChannel.register(serverSelector, SelectionKey.OP_READ);

                } else if (selectKey.isReadable()) {
                    //TODO 读取并处理数据
                }
                iterator.remove();
            }
        }
    }
}

会话管理

在本聊天室中,一个WebSocket连接就视为一个会话,也就是一个用户登录。定义ClientSession来管理每一个连接的SocketChannel。

public class ClientSession {
    private SocketChannel socketChannel;
    private String sessionID;

    public ClientSession(SocketChannel channel) {
        this.socketChannel = channel;
        try {
            MessageDigest sha1 = MessageDigest.getInstance("sha1");
            sha1.update(Util.longToByteArray(System.currentTimeMillis()));
            BigInteger bi = new BigInteger(sha1.digest());
            sessionID = bi.toString(16);
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
        }
    }

    public SocketChannel getSocketChannel() {
        return socketChannel;
    }

    public String getSessionID() {
        return sessionID;
    }
}

会话的几种状态

对于聊天室的一个用户从建立连接到释放连接,服务端最关心的无非是四个关键点:会话建立,收到来自客户端的消息,会话关闭,抛出异常。可以将其抽象成接口。

interface WebSocketListener {
    void onOpen(ClientSession session) throws IOException;

    void onMessage(ClientSession session) throws IOException;

    void onException(ClientSession session, Exception ex);

    void onClose(ClientSession session) throws IOException;
}

处理SocketChannel

前面一开始的关于ServerSocketChannel代码中,还没对SocketChannel进行处理,现在来补上。

public class WebSocketServer {

    private Selector serverSelector;
    private WebSocketListener socketListener;
    private boolean isRunning = true;

    //...部分代码省略

    public void run() throws IOException {
        while (isRunning) {
            int selectCount = serverSelector.select();
            if (selectCount == 0)
                continue;

            Iterator<SelectionKey> iterator = serverSelector.selectedKeys().iterator();
            while (iterator.hasNext()) {
                SelectionKey selectKey = iterator.next();

                if (selectKey.isAcceptable()) {
                    //重复代码省略
                } else if (selectKey.isReadable()) {
                    try {
                        SocketChannel socketChannel = (SocketChannel) selectKey.channel();
                        ClientSession session = (ClientSession) selectKey.attachment();  //用前面定义的ClientSession来作为SocketChannel的attach object,方便存储关于SocketChannel的其他信息,容易管理。

                        if (session == null) {
                            //如果SocketChannel还没有被ClientSession绑定,认为这是一个新连接,需要完成握手
                            byte[] byteArray = Util.readByteArray(socketChannel);
                            System.out.println(new String(byteArray));
                            WSProtocol.Header header = WSProtocol.Header.decodeFromString(new String(byteArray));
                            String receiveKey = header.getHeader("Sec-WebSocket-Key");
                            String response = WSProtocol.getHandShakeResponse(receiveKey);
                            socketChannel.write(ByteBuffer.wrap(response.getBytes()));

                            ClientSession newSession = new ClientSession(socketChannel);
                            selectKey.attach(newSession);
                            socketListener.onOpen(newSession);  //会话打开
                        } else {
                            //收到数据,交给上面定义的接口处理
                            socketListener.onMessage(session);
                        }
                    } catch (IOException e) {
                        e.printStackTrace();
                        //出现异常,进行一系列处理
                        selectKey.channel().close();
                        selectKey.cancel();

                        ClientSession attSession = (ClientSession) selectKey.attachment();
                        socketListener.onException(attSession, e);  //抛出异常
                        socketListener.onClose(attSession);  //强制关闭抛出异常的连接
                    }
                }
                iterator.remove();
            }
        }
    }
}
//...省略下部分代码

数据包的处理

数据包的处理放在单独一个类里面

class WSProtocol {

    static class Header {
        private Map<String, String> headers = new HashMap<>();

        String getHeader(String key) {
            return headers.get(key);
        }

        static Header decodeFromString(String headers) {
            Header header = new Header();

            Map<String, String> headerMap = new HashMap<>();
            String[] headerArray = headers.split("\r\n");
            for (String headerLine : headerArray) {
                if (headerLine.contains(":")) {
                    int splitPos = headerLine.indexOf(":");
                    String key = headerLine.substring(0, splitPos);
                    String value = headerLine.substring(splitPos + 1).trim();
                    headerMap.put(key, value);
                }
            }
            header.headers = headerMap;
            return header;
        }
    }

    static String getHandShakeResponse(String receiveKey) {
        String keyOrigin = receiveKey + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
        MessageDigest sha1;
        String accept = null;
        try {
            sha1 = MessageDigest.getInstance("sha1");
            sha1.update(keyOrigin.getBytes());
            accept = new String(Base64.getEncoder().encode(sha1.digest()));
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
        }
        String echoHeader = "";
        echoHeader += "HTTP/1.1 101 Switching Protocols\r\n";
        echoHeader += "Upgrade: websocket\r\n";
        echoHeader += "Connection: Upgrade\r\n";
        echoHeader += "Sec-WebSocket-Accept: " + accept + "\r\n";
        echoHeader += "\r\n";

        return echoHeader;
    }
}

另外,还需要实现接口来处理会话的几种状态。

class WebSocketListenerImpl implements WebSocketListener {

    private Map<String, ClientSession> connSessionMap = new HashMap<>();

    @Override
    public void onOpen(ClientSession session) throws IOException {
        connSessionMap.put(session.getSessionID(), session);
        sendBoardCast(session.getSocketChannel().socket().getInetAddress().getHostName() + ":" +
                session.getSocketChannel().socket().getPort() + " Join", session);
        Log.info("session open: " + session.getSessionID());
    }

    @Override
    public void onMessage(ClientSession session) throws IOException {

        SocketChannel socketChannel = session.getSocketChannel();

        byte[] bytesData = Util.readByteArray(socketChannel);

        //opcode为8,对方主动断开连接
        if ((bytesData[0] & 0xf) == 8) {
            throw new IOException("session disconnect.");
        }

        byte payloadLength = (byte) (bytesData[1] & 0x7f);
        byte[] mask = Arrays.copyOfRange(bytesData, 2, 6);
        byte[] payloadData = Arrays.copyOfRange(bytesData, 6, bytesData.length);
        for (int i = 0; i < payloadData.length; i++) {
            payloadData[i] = (byte) (payloadData[i] ^ mask[i % 4]);
        }

        String echoData =
                "[" + session.getSocketChannel().socket().getInetAddress().getHostAddress() + ":" +
                        session.getSocketChannel().socket().getPort() + "]" +
                        (new String(payloadData));

        sendBoardCast(echoData, session);
    }

    @Override
    public void onException(ClientSession session, Exception ex) {
        Log.info("exception catch: " + ex.getMessage());
    }

    @Override
    public void onClose(ClientSession session) throws IOException {
        connSessionMap.remove(session.getSessionID());

        sendBoardCast(session.getSocketChannel().socket().getInetAddress().getHostName() + ":" +
                session.getSocketChannel().socket().getPort() + " Leave", session);

        Log.info("closed sessionId = " + session.getSessionID());
    }

    private void sendBoardCast(String message, ClientSession ownSession) throws IOException {
        Iterator<ClientSession> iterator = connSessionMap.values().iterator();
        while (iterator.hasNext()) {
            ClientSession nextSession = iterator.next();
            if (nextSession == ownSession) {
                continue;
            }
            byte[] boardCastData = new byte[2 + message.getBytes().length];
            boardCastData[0] = (byte) 0x81;
            boardCastData[1] = (byte) message.getBytes().length;
            System.arraycopy(message.getBytes(), 0, boardCastData, 2, message.getBytes().length);

            nextSession.getSocketChannel().write(ByteBuffer.wrap(boardCastData));
        }
    }
}

这里考虑的极其简单,都是数据长度不超过126,且都是文本不分片的情况,有兴趣的可以按照WebSocket的文档将数据操作的过程写完整。

最后写一张简单的测试页面

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Chat</title>
</head>

<script>
    var socket = null;
    var isConn = false;

    socket = new WebSocket('ws://127.0.0.1:8888');
    socket.onerror = function (err) {
        console.log(err);
        addError('连接错误')
    };
    socket.onopen = function () {
        isConn = true;
        addMessage('连接成功');
        console.log('open');
    };
    socket.onmessage = function (event) {
        console.log(event.data);
        addMessage(event.data);
    };
    socket.onclose = function () {
        console.log('close')
    };


    function sendMessage() {
        var sendText = document.getElementById('input_text').value;
        if (!isConn) {
            addError('发送失败');
        } else {
            if (!sendText) {
                addError('不要发送空消息');
            } else {
                addMessage('<label style="font-style: oblique">' + '[我]' + sendText + '</label>');
                socket.send(sendText);
            }
        }
    }

    function addMessage(message) {
        var textShow = document.getElementById('show-message');
        textShow.innerHTML += message + '<br>'
    }

    function addError(error) {
        var textShow = document.getElementById('show-message');
        textShow.innerHTML += '<label style="color: red">[Error] ' + error + '</label><br>'
    }
</script>
<style>
    button {
        border-radius: 5px;
        padding: 8px;
        color: white;
        border: none;
    }

    input {
        border-radius: 4px;
        width: 300px;
        padding: 6px;
        border: 1px solid dodgerblue;
    }

    #show-message {
        border-radius: 4px;
        height: 320px;
        width: 480px;
        border: 1px double darkgray;
        padding: 8px;
    }
</style>
<body>
<div id="show-message">
</div>
<div style="margin-top: 20px">
    <input id="input_text" type="text" placeholder="message">
    <button onclick="sendMessage()" style="background-color: dodgerblue">Send</button>
</div>
</body>
</html>

运行效果

这里写图片描述

猜你喜欢

转载自blog.csdn.net/Kurozaki_Kun/article/details/78843783