High-Performance Netty: Developing a Private Protocol Stack

Introduction

This article continues the Netty series by showing how to build a private protocol stack. It covers, in order:

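  1. What is a private protocol stack?
  2. What features should a private stack provide?
  3. The general communication model of a private stack
  4. The data transmission format of a private stack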

What Is a Private Protocol Stack?

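Communication protocols fall into two groups: public and private. HTTP and WebSocket, covered in earlier articles, are public protocols: they are widely known, and their standards are set by trusted public bodies. A private protocol is usually used inside a company or organization, or for connecting networks and users to its systems. An external party that wants to connect must implement the same non-standard protocol; otherwise it cannot interoperate with, or even join, the existing network.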

Features of the Private Stack

The most basic capabilities any protocol stack needs are message exchange and service invocation. A Netty-based protocol stack can provide the following:

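  1. High-performance asynchronous communication
  2. A message codec framework that serializes and deserializes POJOs
  3. Whitelist-based access authentication by IP address
  4. Link validity checking (heartbeats)
  5. Detection of broken links and automatic reconnection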

Communication Model

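Here, the communication model means the whole lifecycle of a connection: establishing access, exchanging messages, and disconnecting.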


The process, step by step:

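  1. The client sends a handshake request carrying valid authentication information.
  2. The server verifies the client's identity, checking that the information is valid and legitimate, and then returns a handshake response.
  3. Once the link is established, the server can send business messages to the client, and the client can likewise send business messages to the server.
  4. Once the link is established, the client and server can exchange heartbeat messages.
  5. Finally, when the server exits it closes the connection; the client detects that the peer has closed and passively closes its own connection.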
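Seen from the client, the five steps above form a small state machine. The sketch below models it in plain Java so the legal order of messages is explicit. The names LinkState, LinkEvent and ClientLinkStateMachine are hypothetical and do not appear in the demo code; the sketch uses no Netty API.

```java
// Hypothetical client-side model of the connection lifecycle described above.
enum LinkState { CONNECTED, LOGIN_SENT, ESTABLISHED, CLOSED }

enum LinkEvent { SEND_LOGIN_REQ, LOGIN_OK, LOGIN_REJECTED, HEARTBEAT, BUSINESS_MSG, PEER_CLOSED }

final class ClientLinkStateMachine {
    private LinkState state = LinkState.CONNECTED;

    LinkState state() {
        return state;
    }

    /** Applies one event and returns the new state; events illegal in the current state throw. */
    LinkState on(LinkEvent event) {
        if (event == LinkEvent.PEER_CLOSED) {
            // Step 5: the peer closed the link, so the client closes its side too.
            state = LinkState.CLOSED;
            return state;
        }
        switch (state) {
            case CONNECTED:
                // Step 1: the client opens with a handshake request.
                if (event == LinkEvent.SEND_LOGIN_REQ) {
                    state = LinkState.LOGIN_SENT;
                    return state;
                }
                break;
            case LOGIN_SENT:
                // Step 2: the server validates the client and answers.
                if (event == LinkEvent.LOGIN_OK) {
                    state = LinkState.ESTABLISHED;
                    return state;
                }
                if (event == LinkEvent.LOGIN_REJECTED) {
                    state = LinkState.CLOSED;
                    return state;
                }
                break;
            case ESTABLISHED:
                // Steps 3 and 4: business and heartbeat messages flow both ways.
                if (event == LinkEvent.HEARTBEAT || event == LinkEvent.BUSINESS_MSG) {
                    return state;
                }
                break;
            default:
                break;
        }
        throw new IllegalStateException(event + " is not allowed in state " + state);
    }
}
```

A heartbeat sent before the handshake completes is rejected here, which matches the demo: HeartBeatReqHandler only starts its timer after it sees a successful LOGIN_RESP.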

Transmission Format

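When we studied the application-layer HTTP protocol earlier, we saw that a request consists of three parts: the request line, the request headers, and the request body. A private protocol can use a similar layout.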

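Here, each message consists of a message header and a message body. As the codec classes below show, the header is written in this order: crcCode (int, 4 bytes), length (int, 4 bytes), sessionID (long, 8 bytes), type (byte), priority (byte), and the attachment map (a 4-byte entry count followed by each key and value). The length field holds the number of bytes that follow it. The body comes last as a length-prefixed serialized object, or as a single int 0 when there is no body.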
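To make the byte layout concrete, here is a minimal sketch that uses java.nio.ByteBuffer instead of Netty's ByteBuf. It writes a message with no attachments and no body, the simplest case MessageEncoder produces, and back-fills the length field with the same rule the encoder uses (total bytes minus 8). The class HeaderLayout is hypothetical.

```java
import java.nio.ByteBuffer;

// Hypothetical helper showing the header layout that MessageEncoder writes.
// Byte order is big-endian, the default for both ByteBuffer and Netty's ByteBuf.
final class HeaderLayout {
    static final int CRC_CODE = 0xabef0101;
    // crcCode(4) + length(4) + sessionID(8) + type(1) + priority(1) + attachment count(4)
    static final int FIXED_HEADER_BYTES = 4 + 4 + 8 + 1 + 1 + 4;
    static final int NO_BODY_MARKER_BYTES = 4;

    /** Encodes a message with no attachments and no body. */
    static byte[] encodeEmpty(long sessionId, byte type, byte priority) {
        ByteBuffer buf = ByteBuffer.allocate(FIXED_HEADER_BYTES + NO_BODY_MARKER_BYTES);
        buf.putInt(CRC_CODE);
        buf.putInt(0);            // length placeholder, back-filled below
        buf.putLong(sessionId);
        buf.put(type);
        buf.put(priority);
        buf.putInt(0);            // attachment entry count
        buf.putInt(0);            // the encoder writes int 0 when there is no body
        // Same rule as MessageEncoder: length = total bytes - 8 (crcCode + length field).
        buf.putInt(4, buf.position() - 8);
        return buf.array();
    }

    /** Reads the length field at offset 4, where the frame decoder looks for it. */
    static int lengthField(byte[] frame) {
        return ByteBuffer.wrap(frame).getInt(4);
    }
}
```

An empty message is therefore 26 bytes with a length value of 18, and a decoder told that the length field sits at offset 4 and is 4 bytes long recovers 4 + 4 + 18 = 26 bytes.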

Code Implementation

Because this demo is fairly complete, it involves quite a few classes. They are grouped as follows:

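Class Overview

  1. System configuration: MessageType (message type codes), Constant (addresses and ports)
  2. Entities: Header (message header), Message (header plus body)
  3. Codecs: ChannelBufferByteInput and ChannelBufferByteOutput (adapters between Netty's ByteBuf and JBoss Marshalling), MarshallingCodecFactory, MarshallingEncoder and MarshallingDecoder (object serialization), MessageEncoder and MessageDecoder (whole-message encoding and framing)
  4. Server and client: Server with LoginAuthRespHandler and HeartBeatRespHandler; Client with LoginAuthReqHandler and HeartBeatReqHandler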

Maven Dependencies

        <dependency>
            <groupId>org.jboss.marshalling</groupId>
            <artifactId>jboss-marshalling</artifactId>
            <version>2.0.9.Final</version>
        </dependency>
        <dependency>
            <groupId>org.jboss.marshalling</groupId>
            <artifactId>jboss-marshalling-serial</artifactId>
            <version>2.0.9.Final</version>
        </dependency>
        <dependency>
            <groupId>io.netty</groupId>
            <artifactId>netty-all</artifactId>
            <version>4.1.51.Final</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>
        <dependency>
            <groupId>commons-logging</groupId>
            <artifactId>commons-logging</artifactId>
            <version>1.1.1</version>
        </dependency>

System Configuration Classes

MessageType.java

public enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2),
    LOGIN_REQ((byte) 3), LOGIN_RESP((byte) 4),
    HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);
    private final byte value;
    private MessageType(byte value) {
        this.value = value;
    }
    public byte value() {
        return this.value;
    }
}
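The handlers in this demo compare raw bytes, for example `getType() == MessageType.HEARTBEAT_REQ.value()`. If you would rather work with enum constants after decoding, a reverse lookup can be added. The fromValue method below is a hypothetical addition, not part of the original code.

```java
// MessageType with a hypothetical reverse lookup from the wire byte.
enum MessageType {
    SERVICE_REQ((byte) 0), SERVICE_RESP((byte) 1), ONE_WAY((byte) 2),
    LOGIN_REQ((byte) 3), LOGIN_RESP((byte) 4),
    HEARTBEAT_REQ((byte) 5), HEARTBEAT_RESP((byte) 6);

    private final byte value;

    MessageType(byte value) {
        this.value = value;
    }

    byte value() {
        return value;
    }

    /** Returns the constant whose wire value is {@code value}; throws for unknown bytes. */
    static MessageType fromValue(byte value) {
        for (MessageType t : values()) {
            if (t.value == value) {
                return t;
            }
        }
        throw new IllegalArgumentException("Unknown message type: " + value);
    }
}
```

With it, a handler could switch on `MessageType.fromValue(header.getType())` instead of chaining byte comparisons, and an unknown type byte fails loudly instead of silently falling through.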

Constant.java

public class Constant {
    public static final String REMOTEIP = "127.0.0.1"; // server address
    public static final int PORT = 8080;               // server port
    public static final int LOCAL_PORT = 12088;        // local port the client binds to
    public static final String LOCALIP = "127.0.0.1";  // local address the client binds to
}

Entity Classes

Header.java

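import java.util.HashMap;
import java.util.Map;

public final class Header {
    private int crcCode = 0xabef0101; // protocol identifier (fixed value)
    private int length;               // message length
    private long sessionID;           // session ID
    private byte type;                // message type, see MessageType
    private byte priority;            // message priority
    private Map<String, Object> attachment = new HashMap<>(); // optional extension fields
    // ... getters and setters omitted
}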

Message.java

public class Message {
    private Header header;
    private Object body;
    // ... getters and setters omitted
}

Codecs

ChannelBufferByteInput.java

import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteInput;
import java.io.IOException;
/* ByteInput implementation that reads from a Netty ByteBuf */
class ChannelBufferByteInput implements ByteInput {
    private final ByteBuf buffer;
    public ChannelBufferByteInput(ByteBuf buffer) {
        this.buffer = buffer;
    }
    @Override
    public void close() throws IOException {
        // nothing to do
    }
    @Override
    public int available() throws IOException {
        return buffer.readableBytes();
    }
    @Override
    public int read() throws IOException {
        if (buffer.isReadable()) {
            return buffer.readByte() & 0xff;
        }
        return -1;
    }
    @Override
    public int read(byte[] array) throws IOException {
        return read(array, 0, array.length);
    }
    @Override
    public int read(byte[] dst, int dstIndex, int length) throws IOException {
        int available = available();
        if (available == 0) {
            return -1;
        }
        length = Math.min(available, length);
        buffer.readBytes(dst, dstIndex, length);
        return length;
    }
    @Override
    public long skip(long bytes) throws IOException {
        int readable = buffer.readableBytes();
        if (readable < bytes) {
            bytes = readable;
        }
        buffer.readerIndex((int) (buffer.readerIndex() + bytes));
        return bytes;
    }
}

ChannelBufferByteOutput.java

import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteOutput;
import java.io.IOException;
/* ByteOutput implementation that writes into a Netty ByteBuf */
class ChannelBufferByteOutput implements ByteOutput {
    private final ByteBuf buffer;
    public ChannelBufferByteOutput(ByteBuf buffer) {
        this.buffer = buffer;
    }
    @Override
    public void close() throws IOException {
        // Nothing to do
    }
    @Override
    public void flush() throws IOException {
        // nothing to do
    }
    @Override
    public void write(int b) throws IOException {
        buffer.writeByte(b);
    }
    @Override
    public void write(byte[] bytes) throws IOException {
        buffer.writeBytes(bytes);
    }
    @Override
    public void write(byte[] bytes, int srcIndex, int length) throws IOException {
        buffer.writeBytes(bytes, srcIndex, length);
    }
    /**
     * Return the {@link ByteBuf} which contains the written content
     *
     */
    ByteBuf getBuffer() {
        return buffer;
    }
}

MarshallingCodecFactory.java

import org.jboss.marshalling.Marshaller;
import org.jboss.marshalling.MarshallerFactory;
import org.jboss.marshalling.Marshalling;
import org.jboss.marshalling.MarshallingConfiguration;
import org.jboss.marshalling.Unmarshaller;
import java.io.IOException;

public final class MarshallingCodecFactory {
    /** Creates a JBoss Marshaller using the "serial" protocol (from jboss-marshalling-serial) */
    protected static Marshaller buildMarshalling() throws IOException {
        final MarshallerFactory marshallerFactory = Marshalling
            .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createMarshaller(configuration);
    }
    /** Creates a JBoss Unmarshaller with the same configuration */
    protected static Unmarshaller buildUnMarshalling() throws IOException {
        final MarshallerFactory marshallerFactory = Marshalling
            .getProvidedMarshallerFactory("serial");
        final MarshallingConfiguration configuration = new MarshallingConfiguration();
        configuration.setVersion(5);
        return marshallerFactory.createUnmarshaller(configuration);
    }
}

MarshallingDecoder.java

import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.ByteInput;
import org.jboss.marshalling.Unmarshaller;
import java.io.IOException;
public class MarshallingDecoder {
    private final Unmarshaller unmarshaller;
    public MarshallingDecoder() throws IOException {
        unmarshaller = MarshallingCodecFactory.buildUnMarshalling();
    }
    protected Object decode(ByteBuf in) throws Exception {
        int objectSize = in.readInt();                          // 4-byte length prefix
        ByteBuf buf = in.slice(in.readerIndex(), objectSize);   // view over the object bytes
        ByteInput input = new ChannelBufferByteInput(buf);
        try {
            unmarshaller.start(input);
            Object obj = unmarshaller.readObject();
            unmarshaller.finish();
            in.readerIndex(in.readerIndex() + objectSize);      // skip past the object
            return obj;
        } finally {
            unmarshaller.close();
        }
    }
}

MarshallingEncoder.java

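import io.netty.buffer.ByteBuf;
import org.jboss.marshalling.Marshaller;
import java.io.IOException;

// A plain helper, not a ChannelHandler (so @Sharable does not apply). It holds a
// stateful Marshaller, so each MessageEncoder owns its own instance.
public class MarshallingEncoder {
    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];
    private final Marshaller marshaller;
    public MarshallingEncoder() throws IOException {
        marshaller = MarshallingCodecFactory.buildMarshalling();
    }
    protected void encode(Object msg, ByteBuf out) throws Exception {
        try {
            // Reserve 4 bytes for the length, serialize the object, then back-fill the length
            int lengthPos = out.writerIndex();
            out.writeBytes(LENGTH_PLACEHOLDER);
            ChannelBufferByteOutput output = new ChannelBufferByteOutput(out);
            marshaller.start(output);
            marshaller.writeObject(msg);
            marshaller.finish();
            out.setInt(lengthPos, out.writerIndex() - lengthPos - 4);
        } finally {
            marshaller.close();
        }
    }
}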
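MarshallingEncoder and MarshallingDecoder rely on a common trick: reserve a 4-byte length slot, serialize the object, then write the real length back into the slot, so the reader knows exactly how many bytes to consume. The sketch below reproduces that pattern without Netty or JBoss Marshalling. It swaps in the JDK's ObjectOutputStream and ObjectInputStream for the Marshaller and Unmarshaller, and the class name LengthPrefixedCodec is hypothetical.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;

// Hypothetical stand-in for MarshallingEncoder/MarshallingDecoder.
// Assumption: JDK serialization replaces JBoss Marshalling; the framing is the same.
final class LengthPrefixedCodec {

    /** Writes [4-byte length][serialized object]. */
    static byte[] encode(Object obj) {
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            out.write(new byte[4]);                       // reserve the length slot
            try (ObjectOutputStream oos = new ObjectOutputStream(out)) {
                oos.writeObject(obj);
            }
            byte[] bytes = out.toByteArray();
            // Back-fill: the length counts the payload only, not the 4-byte slot.
            ByteBuffer.wrap(bytes).putInt(0, bytes.length - 4);
            return bytes;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Reads one length-prefixed object and advances {@code in} past it. */
    static Object decode(ByteBuffer in) {
        int size = in.getInt();
        byte[] payload = new byte[size];
        in.get(payload);                                  // consume exactly 'size' bytes
        try (ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(payload))) {
            return ois.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }
}
```

Because each value carries its own length, several values can be written back to back, which is how MessageEncoder places attachment values and the body one after another in a single buffer.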

MessageDecoder.java

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

public class MessageDecoder extends LengthFieldBasedFrameDecoder {
    private final MarshallingDecoder marshallingDecoder;
    public MessageDecoder(int maxFrameLength, int lengthFieldOffset,
        int lengthFieldLength) throws IOException {
        super(maxFrameLength, lengthFieldOffset, lengthFieldLength);
        marshallingDecoder = new MarshallingDecoder();
    }
    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in)
        throws Exception {
        // The parent cuts out one complete frame; null means a half packet, so wait for more bytes
        ByteBuf frame = (ByteBuf) super.decode(ctx, in);
        if (frame == null) {
            return null;
        }
        try {
            Message message = new Message();
            Header header = new Header();
            header.setCrcCode(frame.readInt());
            header.setLength(frame.readInt());
            header.setSessionID(frame.readLong());
            header.setType(frame.readByte());
            header.setPriority(frame.readByte());
            int size = frame.readInt();                 // number of attachment entries
            if (size > 0) {
                Map<String, Object> attch = new HashMap<String, Object>(size);
                for (int i = 0; i < size; i++) {
                    byte[] keyArray = new byte[frame.readInt()];
                    frame.readBytes(keyArray);
                    String key = new String(keyArray, "UTF-8");
                    attch.put(key, marshallingDecoder.decode(frame));
                }
                header.setAttachment(attch);
            }
            // More than the 4-byte "no body" marker left: a serialized body follows
            if (frame.readableBytes() > 4) {
                message.setBody(marshallingDecoder.decode(frame));
            }
            message.setHeader(header);
            return message;
        } finally {
            frame.release();                            // the parent returns a retained slice
        }
    }
}
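MessageDecoder relies on LengthFieldBasedFrameDecoder to deal with TCP half packets and sticky packets. With the settings used by Server and Client (a length field at offset 4, 4 bytes long, no adjustment), a frame is 4 + 4 + (length field value) bytes long. The simplified model below, built around a hypothetical FrameSplitter class, applies that arithmetic to a byte stream arriving in arbitrary chunks. Netty's real decoder additionally supports lengthAdjustment and initialBytesToStrip, and it can discard oversized frames instead of failing.

```java
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Simplified, hypothetical model of LengthFieldBasedFrameDecoder(maxFrameLength, 4, 4).
final class FrameSplitter {
    private static final int LENGTH_FIELD_OFFSET = 4;
    private static final int LENGTH_FIELD_LENGTH = 4;
    private final int maxFrameLength;
    private final ByteArrayOutputStream pending = new ByteArrayOutputStream();

    FrameSplitter(int maxFrameLength) {
        this.maxFrameLength = maxFrameLength;
    }

    /** Accepts bytes as they arrive from the socket and returns every complete frame. */
    List<byte[]> feed(byte[] chunk) {
        pending.write(chunk, 0, chunk.length);
        byte[] data = pending.toByteArray();
        List<byte[]> frames = new ArrayList<>();
        int pos = 0;
        // We need the bytes up to the end of the length field before the frame size is known.
        while (data.length - pos >= LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH) {
            int lengthValue = ByteBuffer.wrap(data, pos + LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH).getInt();
            if (lengthValue < 0) {
                throw new IllegalStateException("negative length field: " + lengthValue);
            }
            // lengthAdjustment = 0, so the frame spans offset + field + value bytes.
            long frameLength = (long) LENGTH_FIELD_OFFSET + LENGTH_FIELD_LENGTH + lengthValue;
            if (frameLength > maxFrameLength) {
                throw new IllegalStateException("frame too long: " + frameLength);
            }
            if (data.length - pos < frameLength) {
                break;                                      // half packet: wait for more bytes
            }
            frames.add(Arrays.copyOfRange(data, pos, pos + (int) frameLength));
            pos += (int) frameLength;                       // sticky packets: keep looping
        }
        pending.reset();
        pending.write(data, pos, data.length - pos);        // keep the unconsumed tail
        return frames;
    }
}
```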

MessageEncoder.java

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;
import java.io.IOException;
import java.util.Map;

public final class MessageEncoder extends MessageToByteEncoder<Message> {
    private final MarshallingEncoder marshallingEncoder;
    public MessageEncoder() throws IOException {
        this.marshallingEncoder = new MarshallingEncoder();
    }
    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg,
        ByteBuf sendBuf) throws Exception {
        if (msg == null || msg.getHeader() == null)
            throw new Exception("The encode message is null");
        Header header = msg.getHeader();
        sendBuf.writeInt(header.getCrcCode());
        sendBuf.writeInt(header.getLength());          // placeholder, overwritten at the end
        sendBuf.writeLong(header.getSessionID());
        sendBuf.writeByte(header.getType());
        sendBuf.writeByte(header.getPriority());
        sendBuf.writeInt(header.getAttachment().size());
        for (Map.Entry<String, Object> param : header.getAttachment().entrySet()) {
            byte[] keyArray = param.getKey().getBytes("UTF-8");
            sendBuf.writeInt(keyArray.length);
            sendBuf.writeBytes(keyArray);
            marshallingEncoder.encode(param.getValue(), sendBuf);
        }
        if (msg.getBody() != null) {
            marshallingEncoder.encode(msg.getBody(), sendBuf);
        } else {
            sendBuf.writeInt(0);                        // "no body" marker
        }
        // length = bytes after the crcCode and length fields (total - 8)
        sendBuf.setInt(4, sendBuf.readableBytes() - 8);
    }
}

Server and Client

Server: Server.java

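import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.timeout.ReadTimeoutHandler;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.io.IOException;

public class Server {
    private static final Log LOG = LogFactory.getLog(Server.class);
    public void bind() throws Exception {
        // Configure the server's NIO thread groups
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();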
        ServerBootstrap b = new ServerBootstrap();
        b.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
            .option(ChannelOption.SO_BACKLOG, 100)
            .handler(new LoggingHandler(LogLevel.INFO))
            .childHandler(new ChannelInitializer<SocketChannel>() {
                @Override
                public void initChannel(SocketChannel ch)
                    throws IOException {
                ch.pipeline().addLast(
                    new MessageDecoder(1024 * 1024, 4, 4));
                ch.pipeline().addLast(new MessageEncoder());
                ch.pipeline().addLast("readTimeoutHandler",
                    new ReadTimeoutHandler(50));
                ch.pipeline().addLast(new LoginAuthRespHandler());
                ch.pipeline().addLast("HeartBeatHandler",
                    new HeartBeatRespHandler());
                }
            });
        // Bind the port and wait synchronously until binding succeeds
        b.bind(Constant.REMOTEIP, Constant.PORT).sync();
        LOG.info("server start ok : "
            + (Constant.REMOTEIP + " : " + Constant.PORT));
    }
    public static void main(String[] args) throws Exception {
        new Server().bind();
    }
}

HeartBeatRespHandler.java

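import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

// In Netty 4.1, channelRead/channelActive are defined on ChannelInboundHandlerAdapter,
// not on ChannelHandlerAdapter, so inbound handlers must extend it.
public class HeartBeatRespHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(HeartBeatRespHandler.class);
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
        throws Exception {
        Message message = (Message) msg;
        // Reply to a heartbeat request with a heartbeat response; pass other messages on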
        if (message.getHeader() != null
            && message.getHeader().getType() == MessageType.HEARTBEAT_REQ
                .value()) {
            LOG.info("Receive client heart beat message : ---> "
                + message);
            Message heartBeat = buildHeatBeat();
            LOG.info("Send heart beat response message to client : ---> "
                    + heartBeat);
            ctx.writeAndFlush(heartBeat);
        } else
            ctx.fireChannelRead(msg);
    }
    // Builds a heartbeat response message
    private Message buildHeatBeat() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.HEARTBEAT_RESP.value());
        message.setHeader(header);
        return message;
    }
}

LoginAuthRespHandler.java

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.net.InetSocketAddress;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class LoginAuthRespHandler extends ChannelInboundHandlerAdapter {
    private final static Log LOG = LogFactory.getLog(LoginAuthRespHandler.class);
    // Cache of nodes that have already logged in
    private Map<String, Boolean> nodeCheck = new ConcurrentHashMap<String, Boolean>();
    private String[] whiteList = { "127.0.0.1", "192.168.1.104" };
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        Message message = (Message) msg;
        // Handle handshake (login) requests; pass all other messages through
        if (message.getHeader() != null
            && message.getHeader().getType() == MessageType.LOGIN_REQ.value()) {
            String nodeIndex = ctx.channel().remoteAddress().toString();
            Message loginResp;
            if (nodeCheck.containsKey(nodeIndex)) {
                // Duplicate login: reject
                loginResp = buildResponse((byte) -1);
            } else {
                InetSocketAddress address = (InetSocketAddress) ctx.channel().remoteAddress();
                String ip = address.getAddress().getHostAddress();
                boolean isOK = Arrays.asList(whiteList).contains(ip); // IP whitelist check
                loginResp = buildResponse(isOK ? (byte) 0 : (byte) -1);
                if (isOK) {
                    nodeCheck.put(nodeIndex, true);
                }
            }
            LOG.info("The login response is : " + loginResp
                + " body [" + loginResp.getBody() + "]");
            ctx.writeAndFlush(loginResp);
        } else {
            ctx.fireChannelRead(msg);
        }
    }
    private Message buildResponse(byte result) {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_RESP.value());
        message.setHeader(header);
        message.setBody(result);   // 0 = success, -1 = failure
        return message;
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        cause.printStackTrace();
        nodeCheck.remove(ctx.channel().remoteAddress().toString()); // remove the cached login
        ctx.close();
        ctx.fireExceptionCaught(cause);
    }
}
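The login decision itself does not depend on Netty, so it can be pulled out and unit-tested on its own. The sketch below is a hypothetical LoginPolicy class that keeps the handler's result codes (0 accepted, -1 rejected). As a design change, it uses ConcurrentHashMap.putIfAbsent so the duplicate check and the insert happen in one atomic step. Note that in the demo nodeCheck is an instance field and Server creates a new LoginAuthRespHandler for every channel, so each connection gets its own map; sharing one policy object across handlers is one way to make the duplicate check span connections, provided entries are removed when a node goes away, as logout does here.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical, Netty-free version of the decision made in LoginAuthRespHandler.
// Result codes match the handler: 0 = accepted, -1 = rejected.
final class LoginPolicy {
    private final Set<String> whiteList;
    private final Map<String, Boolean> nodeCheck = new ConcurrentHashMap<>();

    LoginPolicy(String... whiteList) {
        this.whiteList = new HashSet<>(Arrays.asList(whiteList));
    }

    /**
     * @param nodeKey identifies the peer; the handler uses remoteAddress().toString()
     * @param ip      the peer's IP address, checked against the whitelist
     */
    byte login(String nodeKey, String ip) {
        if (!whiteList.contains(ip)) {
            return -1;                                   // not whitelisted
        }
        // putIfAbsent checks for and records the login in one atomic step.
        return nodeCheck.putIfAbsent(nodeKey, Boolean.TRUE) == null ? (byte) 0 : (byte) -1;
    }

    /** Forgets a node, as exceptionCaught does, so it may log in again. */
    void logout(String nodeKey) {
        nodeCheck.remove(nodeKey);
    }
}
```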

Client: Client.java

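import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.timeout.ReadTimeoutHandler;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.net.InetSocketAddress;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;

public class Client {
    private static final Log LOG = LogFactory.getLog(Client.class);
    private ScheduledExecutorService executor = Executors
            .newScheduledThreadPool(1);
    EventLoopGroup group = new NioEventLoopGroup();
    public void connect(int port, String host) throws Exception {
        // Configure the client on the NIO thread group above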
        try {
            Bootstrap b = new Bootstrap();
            b.group(group).channel(NioSocketChannel.class)
                    .option(ChannelOption.TCP_NODELAY, true)
                    .handler(new ChannelInitializer<SocketChannel>() {
                        @Override
                        public void initChannel(SocketChannel ch)
                                throws Exception {
                            ch.pipeline().addLast(
                                    new MessageDecoder(1024 * 1024, 4, 4));
                            ch.pipeline().addLast("MessageEncoder",
                                    new MessageEncoder());
                            ch.pipeline().addLast("readTimeoutHandler",
                                    new ReadTimeoutHandler(50));
                            ch.pipeline().addLast("LoginAuthHandler",
                                    new LoginAuthReqHandler());
                            ch.pipeline().addLast("HeartBeatHandler",
                                    new HeartBeatReqHandler());
                        }
                    });
            // 发起异步连接操作
            ChannelFuture future = b.connect(
                    new InetSocketAddress(host, port),
                    new InetSocketAddress(Constant.LOCALIP,
                            Constant.LOCAL_PORT)).sync();
            // Block until the channel is closed
            future.channel().closeFuture().sync();
        } finally {
            // When the link closes or the connection fails, reconnect after 1 second
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        TimeUnit.SECONDS.sleep(1);
                        try {
                            connect(Constant.PORT, Constant.REMOTEIP); // reconnect
                        } catch (Exception e) {
                            e.printStackTrace();
                        }
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            });
        }
    }
    public static void main(String[] args) throws Exception {
        new Client().connect(Constant.PORT, Constant.REMOTEIP);
    }
}
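Client's reconnect logic hangs off the finally block: whenever connect returns or throws, a task waits one second and calls connect again. The same policy can be written as an explicit loop, which is easier to reason about and to test. The sketch below assumes a hypothetical Connector interface that stands in for Bootstrap.connect(...).sync() followed by closeFuture().sync(), and it adds a maxAttempts cap only so the loop terminates in a test; the demo Client retries indefinitely.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical loop form of Client's reconnect policy.
final class Reconnector {
    /** Stands in for Bootstrap.connect(...).sync() followed by closeFuture().sync(). */
    interface Connector {
        void connectAndAwaitClose() throws Exception;
    }

    /**
     * Repeats connect / wait-for-close / sleep until maxAttempts cycles have run
     * or the thread is interrupted. Returns the number of attempts made.
     */
    static int run(Connector connector, long delayMillis, int maxAttempts) {
        int attempts = 0;
        while (attempts < maxAttempts) {
            attempts++;
            try {
                connector.connectAndAwaitClose();        // returns when the link closes
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();      // keep the flag and stop retrying
                break;
            } catch (Exception e) {
                // Connection refused, handshake rejected, read timeout, ...: retry below.
            }
            try {
                TimeUnit.MILLISECONDS.sleep(delayMillis); // Client waits 1 second here
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                break;
            }
        }
        return attempts;
    }
}
```

Treating interruption as a stop signal gives the application a clean way to shut the reconnect loop down, which the executor-based version in Client does not offer.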

HeartBeatReqHandler.java

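import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;

public class HeartBeatReqHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(HeartBeatReqHandler.class);
    private volatile ScheduledFuture<?> heartBeat;
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // Handshake succeeded: start sending heartbeats every 5 seconds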
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP
                .value()) {
            heartBeat = ctx.executor().scheduleAtFixedRate(
                    new HeartBeatReqHandler.HeartBeatTask(ctx), 0, 5000,
                    TimeUnit.MILLISECONDS);
        } else if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.HEARTBEAT_RESP
                .value()) {
            LOG.info("Client receive server heart beat message : ---> "
                            + message);
        } else
            ctx.fireChannelRead(msg);
    }
    private class HeartBeatTask implements Runnable {
        private final ChannelHandlerContext ctx;
        public HeartBeatTask(final ChannelHandlerContext ctx) {
            this.ctx = ctx;
        }
        @Override
        public void run() {
            Message heatBeat = buildHeatBeat();
            LOG.info("Client send heart beat message to server : ---> "
                            + heatBeat);
            ctx.writeAndFlush(heatBeat);
        }
        private Message buildHeatBeat() {
            Message message = new Message();
            Header header = new Header();
            header.setType(MessageType.HEARTBEAT_REQ.value());
            message.setHeader(header);
            return message;
        }
    }
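    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Also stop the heartbeat task when the link closes normally
        cancelHeartBeat();
        ctx.fireChannelInactive();
    }
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        cause.printStackTrace();
        cancelHeartBeat();
        ctx.fireExceptionCaught(cause);
    }
    private void cancelHeartBeat() {
        if (heartBeat != null) {
            heartBeat.cancel(true);
            heartBeat = null;
        }
    }
}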
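The heartbeat interval and the read timeout work as a pair: the client sends a heartbeat every 5 seconds, and both pipelines install ReadTimeoutHandler(50), which closes the channel once nothing has been read for 50 seconds. The deterministic model below uses a hypothetical ReadTimeoutModel class with caller-supplied timestamps in milliseconds to show why a live link never times out while a silent one does. Netty's handler schedules a timer on the event loop rather than being polled like this.

```java
import java.util.concurrent.TimeUnit;

// Hypothetical, deterministic model of ReadTimeoutHandler(50): the caller passes
// timestamps in milliseconds instead of the handler scheduling a timer.
final class ReadTimeoutModel {
    private final long timeoutMillis;
    private long lastReadMillis;

    ReadTimeoutModel(long timeout, TimeUnit unit, long startMillis) {
        this.timeoutMillis = unit.toMillis(timeout);
        this.lastReadMillis = startMillis;
    }

    /** Any inbound message, heartbeats included, resets the idle clock. */
    void onRead(long nowMillis) {
        lastReadMillis = nowMillis;
    }

    /** True once nothing has been read for at least the timeout. */
    boolean isTimedOut(long nowMillis) {
        return nowMillis - lastReadMillis >= timeoutMillis;
    }
}
```

With a 5-second interval and a 50-second timeout, the link survives several consecutive lost heartbeats before it is closed, which tolerates occasional delays without keeping a dead peer around for long.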

LoginAuthReqHandler.java

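import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

public class LoginAuthReqHandler extends ChannelInboundHandlerAdapter {
    private static final Log LOG = LogFactory.getLog(LoginAuthReqHandler.class);
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        // Send the handshake request as soon as the TCP connection is up
        ctx.writeAndFlush(buildLoginReq());
    }
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg)
            throws Exception {
        Message message = (Message) msg;
        // For a handshake response, check whether authentication succeeded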
        if (message.getHeader() != null
                && message.getHeader().getType() == MessageType.LOGIN_RESP
                .value()) {
            byte loginResult = (byte) message.getBody();
            if (loginResult != (byte) 0) {
                // Handshake failed: close the connection
                ctx.close();
            } else {
                LOG.info("Login is ok : " + message);
                ctx.fireChannelRead(msg);
            }
        } else
            ctx.fireChannelRead(msg);
    }
    // Builds the login request
    private Message buildLoginReq() {
        Message message = new Message();
        Header header = new Header();
        header.setType(MessageType.LOGIN_REQ.value());
        message.setHeader(header);
        return message;
    }
    // Propagate exceptions to the next handler
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
            throws Exception {
        ctx.fireExceptionCaught(cause);
    }
}

Conclusion

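Building a private protocol stack on Netty means dealing with many reliability concerns. HTTP looks simple on the surface, yet a lot of machinery supports it behind the scenes. A private stack like this one has to address performance and availability explicitly: whether messages are dropped or resent when the link breaks, a more complete codec, timeout handling and custom scheduled tasks, security authentication, and so on.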


Reposted from blog.csdn.net/bieber007/article/details/108467710