Netty - Sticky Packets, Split Packets, and a Custom Packet Protocol

1. Sticky packets and split packets

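Here is a quick introduction to sticky and split packets. Suppose we send "give me a coffee" followed by "give me a tea".

- Sticky packets: the receiver may get give me a coffeegive me a tea in a single read.
- Split packets: it may get give me in one read and a coffeegive me a tea in the next.

Both happen for one main reason: TCP delivers a plain byte stream with no message boundaries, so the data has no stable structure unless we give it one.

We can avoid these problems with simple techniques, such as appending a delimiter to every message:
give me a coffee|give me a tea|

However, if the data itself contains the delimiter, it will be split in the wrong place. The approach I usually use is length + data (16give me a coffee13give me a tea).

Some readers may ask whether digits inside the data cause the same problem. They do not. Once the length has been read, the receiver reads exactly that many bytes and never inspects the content of that fixed-length payload. Digits inside the payload are therefore never mistaken for a length.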
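Here is a minimal sketch of the length + data idea in plain Java. It uses only java.nio.ByteBuffer, with no Netty.

- Length format: like the Client below, it writes the length as a 4-byte big-endian int, not as decimal text such as "16".
- Naming: LengthPrefixDemo is a hypothetical name.
- Framing: the decoder reads the length and then takes exactly that many bytes. A payload containing digits or | therefore still decodes correctly.
- Limitation: this sketch assumes the whole stream has already arrived. Coping with partial reads is the extra work a frame decoder does.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class LengthPrefixDemo {
    // Encodes each message as [4-byte big-endian length][UTF-8 payload], back to back.
    static byte[] encode(List<String> messages) {
        List<byte[]> payloads = new ArrayList<>();
        int total = 0;
        for (String m : messages) {
            byte[] b = m.getBytes(StandardCharsets.UTF_8);
            payloads.add(b);
            total += 4 + b.length;
        }
        ByteBuffer buf = ByteBuffer.allocate(total);
        for (byte[] b : payloads) {
            buf.putInt(b.length); // length first...
            buf.put(b);           // ...then exactly that many bytes
        }
        return buf.array();
    }

    // Decodes a complete stream: read the length, then take exactly `length` bytes.
    // The payload is never scanned, so digits or '|' inside it cannot break framing.
    static List<String> decode(byte[] stream) {
        ByteBuffer buf = ByteBuffer.wrap(stream);
        List<String> out = new ArrayList<>();
        while (buf.remaining() >= 4) {
            int length = buf.getInt();
            byte[] b = new byte[length];
            buf.get(b);
            out.add(new String(b, StandardCharsets.UTF_8));
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> msgs = Arrays.asList("give me a coffee", "give me a tea", "13|42");
        System.out.println(decode(encode(msgs)));
    }
}
```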

Server.java
package com.pk.server;

import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;
import org.jboss.netty.handler.codec.string.StringDecoder;
import org.jboss.netty.handler.codec.string.StringEncoder;

import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Sticky/split packet demo
 * @author hzk
 * @date 2018/10/22
 */
public class Server {
    
    public static void main(String[] args){
        //Server bootstrap
        ServerBootstrap serverBootstrap = new ServerBootstrap();

        //boss threads accept connections on the port, worker threads handle reads and writes
        ExecutorService boss = Executors.newCachedThreadPool();
        ExecutorService worker = Executors.newCachedThreadPool();

        //Set the NIO server socket channel factory
        serverBootstrap.setFactory(new NioServerSocketChannelFactory(boss,worker));

        //Set the pipeline factory
        serverBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            @Override
            public ChannelPipeline getPipeline() throws Exception {
                ChannelPipeline channelPipeline = Channels.pipeline();
                //channelPipeline.addLast("decoder",new MyDecoderHandler());
                channelPipeline.addLast("decoder",new StringDecoder());
                channelPipeline.addLast("encoder",new StringEncoder());
                channelPipeline.addLast("handlerOne",new MyHandlerOne());
                return channelPipeline;
            }
        });

        serverBootstrap.bind(new InetSocketAddress(8888));
        System.out.println("Server Start...");
    }
}

MyHandlerOne.java
package com.pk.server;

import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;

/**
 * Sticky/split packet demo
 * @author hzk
 * @date 2018/10/22
 */
public class MyHandlerOne extends SimpleChannelHandler{

    private int count = 1;
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        System.out.println(e.getMessage()+":"+(count++));
    }
}

Client.java
package com.pk.client;

import java.io.IOException;
import java.net.Socket;
import java.nio.ByteBuffer;

/**
 * @author hzk
 * @date 2018/10/22
 */
public class Client {
    
    public static void main(String[] args) throws IOException {
        Socket socket = new Socket("127.0.0.1", 8888);
        String msg = "cliengggg";
        byte[] bytes = msg.getBytes();
        //The byte-array length is written first as an int, so allocate 4 extra bytes
        ByteBuffer allocate = ByteBuffer.allocate(4+bytes.length);
        allocate.putInt(bytes.length);
        allocate.put(bytes);

        byte[] array = allocate.array();
        for(int i =1;i<1000;i++){
            socket.getOutputStream().write(array);
        }
        socket.close();
    }
}

Server output:
Server Start...
   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	clieng:1
ggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	 	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	:2
cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   		cli	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengg:3
gg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   		cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg:4
   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg   	cliengggg:5
   	cliengggg   	cliengggg   	cliengggg:6

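The output clearly shows sticky and split packets. A single read contains many messages glued together, and a message can be cut in the middle (clieng:1 is followed by ggg). Netty's FrameDecoder can help solve this problem.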
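Where do the runs of whitespace in the StringDecoder output come from? A plausible explanation follows.

- Each frame starts with the 4-byte length written by ByteBuffer.putInt. For "cliengggg" that length is 00 00 00 09.
- StringDecoder decodes those bytes as text. Three NUL bytes and 0x09 (a tab) therefore appear before every message.
- How NUL characters render depends on the console.

FrameBytesDemo is a hypothetical helper that rebuilds one frame the same way the Client does.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

class FrameBytesDemo {
    // Builds one frame the way the Client does: putInt(length), then the payload.
    static byte[] frame(String msg) {
        byte[] bytes = msg.getBytes(StandardCharsets.US_ASCII);
        ByteBuffer buf = ByteBuffer.allocate(4 + bytes.length);
        buf.putInt(bytes.length); // ByteBuffer is big-endian by default
        buf.put(bytes);
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] f = frame("cliengggg");
        StringBuilder hex = new StringBuilder();
        for (byte b : f) {
            hex.append(String.format("%02x ", b));
        }
        // Header bytes 00 00 00 09, followed by the ASCII payload.
        System.out.println(hex.toString().trim());
        // 0x09 is the tab character, which explains the tabs in the StringDecoder output.
        System.out.println(f[3] == '\t');
    }
}
```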

MyDecoderHandler.java
package com.pk.server;

import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.FrameDecoder;

/**
 * @author hzk
 * @date 2018/10/22
 */
public class MyDecoderHandler extends FrameDecoder {
    @Override
    protected Object decode(ChannelHandlerContext channelHandlerContext, Channel channel, ChannelBuffer channelBuffer) throws Exception {
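        if(channelBuffer.readableBytes()>4){

            //Guard against oversized data: discard everything buffered and wait for new data
            if(channelBuffer.readableBytes() >2048){
                channelBuffer.skipBytes(channelBuffer.readableBytes());
                //Nothing is left to read, so stop here instead of calling readInt on an empty buffer
                return null;
            }

            //Mark the current read position
            channelBuffer.markReaderIndex();
            //Length prefix
            int length = channelBuffer.readInt();

            if(channelBuffer.readableBytes() < length){
                channelBuffer.resetReaderIndex();
                //Incomplete packet: keep the remaining bytes buffered and wait for the next data
                return null;
            }

            //Read the payload
            byte[] bytes = new byte[length];
            channelBuffer.readBytes(bytes);
            //Pass it on to the next handler
            return new String(bytes);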
        }

        return null;
    }
}

Server output:
Server Start...
cliengggg:1
cliengggg:2
cliengggg:3
cliengggg:4
cliengggg:5
cliengggg:6
cliengggg:7
cliengggg:8
cliengggg:9
cliengggg:10
cliengggg:11
cliengggg:12
cliengggg:13
cliengggg:14

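After swapping the "decoder" entry in Server's pipeline from StringDecoder to our custom MyDecoderHandler (the commented-out line), each message arrives intact. The sticky/split packet problem is solved.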
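The mechanism behind this fix can be modeled without Netty. The sketch below is a simplified model, not FrameDecoder's actual source.

- A hypothetical CumulationDemo keeps undecoded bytes in a ByteBuffer and appends each new chunk to them.
- It then repeatedly applies the same mark / read length / reset steps as MyDecoderHandler, stopping when a frame is incomplete.
- Feeding it 20 frames cut into 7-byte chunks reproduces both sticky and split packets.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class CumulationDemo {
    // Holds bytes received but not yet decoded (the role of FrameDecoder's cumulation).
    private ByteBuffer cumulation = ByteBuffer.allocate(0);

    // Same steps as MyDecoderHandler.decode: null means "frame incomplete, keep the bytes".
    private String decodeOne() {
        if (cumulation.remaining() < 4) {
            return null;
        }
        cumulation.mark();                 // like markReaderIndex()
        int length = cumulation.getInt();
        if (cumulation.remaining() < length) {
            cumulation.reset();            // like resetReaderIndex()
            return null;
        }
        byte[] bytes = new byte[length];
        cumulation.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    // Appends a newly received chunk and decodes every complete frame it now contains.
    List<String> onData(byte[] chunk) {
        ByteBuffer merged = ByteBuffer.allocate(cumulation.remaining() + chunk.length);
        merged.put(cumulation);
        merged.put(chunk);
        merged.flip();
        cumulation = merged;
        List<String> out = new ArrayList<>();
        String msg;
        while ((msg = decodeOne()) != null) {
            out.add(msg);
        }
        return out;
    }

    public static void main(String[] args) {
        // 20 frames of "cliengggg", each 4 + 9 = 13 bytes.
        ByteBuffer all = ByteBuffer.allocate(20 * 13);
        for (int i = 0; i < 20; i++) {
            all.putInt(9);
            all.put("cliengggg".getBytes(StandardCharsets.UTF_8));
        }
        byte[] stream = all.array();

        // Deliver the stream in 7-byte chunks, so frames are both merged and cut.
        CumulationDemo decoder = new CumulationDemo();
        int received = 0;
        for (int off = 0; off < stream.length; off += 7) {
            byte[] chunk = Arrays.copyOfRange(stream, off, Math.min(off + 7, stream.length));
            received += decoder.onData(chunk).size();
        }
        System.out.println("messages decoded: " + received);
    }
}
```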

2. Custom packet protocol

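Before defining a custom packet protocol, we need to understand a few points.
Q: How do messages flow through the pipeline, and how does one handler pass an object to the next handler?
A: A pipeline contains multiple handlers, and a handler passes an object to the next one by calling ctx.sendUpstream(event).
Q: Why is the object returned from FrameDecoder's decode method the object that gets passed on?
A: Because FrameDecoder takes the returned object and calls sendUpstream with it internally.
Q: What happens to data left unread in the buffer, and why does returning null keep it buffered?
A: Both are handled by cumulation, a buffer field inside FrameDecoder. It holds bytes that were received but not yet decoded. When decode returns null, the unread bytes stay in cumulation. The next chunk of data is appended to them before decode is called again.
Q: What is a socket attack?
A: A packet whose length field declares a huge value, so the receiver keeps buffering while it waits for data that never completes the packet. This is usually called a socket attack or byte-stream attack.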
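One possible refinement of the attack check is sketched below. It is an assumption, not the article's approach.

- The decoders here discard everything once more than 2048 bytes are buffered. That can also drop legitimate frames that simply arrived together.
- The alternative validates the declared length field as soon as it is readable. It rejects the frame if the length is negative or too large.
- The limit and the class name LengthGuardDemo are hypothetical.
- Netty 3 also ships LengthFieldBasedFrameDecoder, whose maxFrameLength argument serves the same purpose.

```java
import java.nio.ByteBuffer;

class LengthGuardDemo {
    // Hypothetical upper bound on a single payload for this sketch.
    static final int MAX_FRAME_LENGTH = 2048;

    // Returns the payload length if a whole frame is buffered, -1 if more bytes are needed.
    // Rejects the frame as soon as the declared length is negative or above the limit,
    // instead of waiting for data that an attacker will never send.
    static int checkFrame(ByteBuffer buf) {
        if (buf.remaining() < 4) {
            return -1;
        }
        int length = buf.getInt(buf.position()); // absolute read: position is unchanged
        if (length < 0 || length > MAX_FRAME_LENGTH) {
            throw new IllegalStateException("invalid declared length " + length);
        }
        return buf.remaining() - 4 >= length ? length : -1;
    }

    public static void main(String[] args) {
        ByteBuffer evil = ByteBuffer.allocate(4);
        evil.putInt(Integer.MAX_VALUE); // a header claiming a ~2 GB payload
        evil.flip();
        try {
            checkFrame(evil);
        } catch (IllegalStateException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```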

Here is an example that verifies that handlers pass objects along with sendUpstream.

Server.java
package com.pipeline;

import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Pipeline demo server
 * @author hzk
 * @date 2018/10/22
 */
public class Server {
    
    public static void main(String[] args){
        //Server bootstrap
        ServerBootstrap serverBootstrap = new ServerBootstrap();

        //boss threads accept connections on the port, worker threads handle reads and writes
        ExecutorService boss = Executors.newCachedThreadPool();
        ExecutorService worker = Executors.newCachedThreadPool();

        //Set the NIO server socket channel factory
        serverBootstrap.setFactory(new NioServerSocketChannelFactory(boss,worker));

        //Set the pipeline factory
        serverBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            @Override
            public ChannelPipeline getPipeline() throws Exception {
                ChannelPipeline channelPipeline = Channels.pipeline();
                channelPipeline.addLast("handlerOne",new MyHandlerOne());
                channelPipeline.addLast("handlerTwo",new MyHandlerTwo());
                return channelPipeline;
            }
        });

        serverBootstrap.bind(new InetSocketAddress(8888));
        System.out.println("Server Start...");
    }
}

MyHandlerOne.java
package com.pipeline;

import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.*;

/**
 * @author hzk
 * @date 2018/10/22
 */
public class MyHandlerOne extends SimpleChannelHandler{

    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        ChannelBuffer channelBuffer = (ChannelBuffer) e.getMessage();
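        //Copy only the readable bytes (array() exposes the whole backing array and is unsupported for direct buffers)
        byte[] array = new byte[channelBuffer.readableBytes()];
        channelBuffer.readBytes(array);
        String msg = new String(array);
        System.out.println("Handler One:"+msg);

        //Pass two new String messages to the next handler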
        ctx.sendUpstream(new UpstreamMessageEvent(ctx.getChannel(),"abc",e.getRemoteAddress()));
        ctx.sendUpstream(new UpstreamMessageEvent(ctx.getChannel(),"efg",e.getRemoteAddress()));
    }
}

MyHandlerTwo.java
package com.pipeline;

import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.channel.MessageEvent;
import org.jboss.netty.channel.SimpleChannelHandler;

/**
 * @author hzk
 * @date 2018/10/22
 */
public class MyHandlerTwo extends SimpleChannelHandler{

    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        String message = (String) e.getMessage();
        System.out.println("Handler Two:"+message);
    }
}

Server output:
Server Start...
Handler One:clientgogogo
Handler Two:abc
Handler Two:efg

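This example confirms that a handler passes objects to the next handler through sendUpstream. With that understood, we can define our own packet protocol:
header (int, 4 bytes) + module number (short, 2 bytes) + command number (short, 2 bytes) + data length (int, 4 bytes) + data
Our custom packet structure is made up of these parts. Responses additionally carry a 4-byte status code after the command number, as ResponseEncoder shows. The project structure and code are listed below.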
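Before the full project code, here is a small stdlib-only sketch of the request layout just described.

- It uses java.nio.ByteBuffer, which is big-endian by default, the same default byte order as Netty's ChannelBuffers.
- The header value is copied from Constants.AbstractDataStructure.PACKAGE_HEAD. PacketCodecDemo itself is a hypothetical name.
- A request with a 3-byte payload therefore occupies 12 + 3 = 15 bytes.

```java
import java.nio.ByteBuffer;

class PacketCodecDemo {
    // Copied from Constants.AbstractDataStructure in this project.
    static final int PACKAGE_HEAD = -37593513;
    // header(4) + module(2) + cmd(2) + length(4)
    static final int HEADER_LENGTH = 4 + 2 + 2 + 4;

    // Writes head | module | cmd | length | data in big-endian order.
    static byte[] encode(short module, short cmd, byte[] data) {
        ByteBuffer buf = ByteBuffer.allocate(HEADER_LENGTH + data.length);
        buf.putInt(PACKAGE_HEAD);
        buf.putShort(module);
        buf.putShort(cmd);
        buf.putInt(data.length);
        buf.put(data);
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] packet = encode((short) 1, (short) 2, new byte[] {10, 20, 30});
        ByteBuffer buf = ByteBuffer.wrap(packet);
        System.out.println("head ok: " + (buf.getInt() == PACKAGE_HEAD));
        System.out.println("module:  " + buf.getShort());
        System.out.println("cmd:     " + buf.getShort());
        System.out.println("length:  " + buf.getInt());
        System.out.println("total:   " + packet.length);
    }
}
```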

[Figure: project structure]

Client.java
package com.ithzk.client;

import com.ithzk.coder.RequestEncoder;
import com.ithzk.coder.ResponseDecoder;
import com.ithzk.constans.Constants;
import com.ithzk.model.Request;
import com.ithzk.module.customspass.request.FightRequest;
import org.jboss.netty.bootstrap.ClientBootstrap;
import org.jboss.netty.channel.*;
import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory;

import java.net.InetSocketAddress;
import java.util.Scanner;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Netty client
 * @author hzk
 * @date 2018/10/8
 */
public class Client {

    public static void main(String[] args) throws InterruptedException {
        //Client bootstrap
        ClientBootstrap clientBootstrap = new ClientBootstrap();

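        //boss threads handle connection setup, worker threads handle reads and writes
        ExecutorService boss = Executors.newCachedThreadPool();
        ExecutorService worker = Executors.newCachedThreadPool();

        //Set the socket channel factory, using the thread pools created above
        clientBootstrap.setFactory(new NioClientSocketChannelFactory(boss,worker));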

        //Set the pipeline factory
        clientBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            @Override
            public ChannelPipeline getPipeline() throws Exception {
                //RequestEncoder -> ResponseDecoder ->ClientHandler
                ChannelPipeline pipeline = Channels.pipeline();
                pipeline.addLast("decoder",new ResponseDecoder());
                pipeline.addLast("encoder",new RequestEncoder());
                pipeline.addLast("clientHandler",new ClientHandler());
                return pipeline;
            }
        });

        //Connect to the server
        ChannelFuture channelFuture = clientBootstrap.connect(new InetSocketAddress(Constants.AbstractNettyConfig.ADDRESS, Constants.AbstractNettyConfig.PORT));
        Channel channel = channelFuture.sync().getChannel();

        System.out.println("Client Start...");

        Scanner scanner = new Scanner(System.in);
        while(true){
            System.out.println("Please input:");
            int fightId = Integer.parseInt(scanner.nextLine());
            int count = Integer.parseInt(scanner.nextLine());

            FightRequest fightRequest = new FightRequest(fightId,count);
            Request request = new Request(Constants.AbstractModule.ONE,Constants.AbstractCmd.ONE,fightRequest.getBytes());
            //Send the request
            channel.write(request);
        }
    }
}

ClientHandler.java
package com.ithzk.client;

import com.ithzk.constans.Constants;
import com.ithzk.model.Response;
import com.ithzk.module.customspass.response.FightResponse;
import org.jboss.netty.channel.*;

/**
 * Handler for received messages
 * @author hzk
 * @date 2018/10/8
 */
public class ClientHandler extends SimpleChannelHandler{

    /**
     * Receive a message
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        Response response = (Response) e.getMessage();
        if(Constants.AbstractModule.ONE == response.getModule()){
            if(Constants.AbstractCmd.ONE == response.getCmd()){
                FightResponse fightResponse = new FightResponse();
                fightResponse.readFromBytes(response.getData());

                System.out.println("ClientHandler->messageReceived:"+fightResponse);
            }
        }else if(Constants.AbstractModule.TWO == response.getModule()){
            System.out.println("CMD:"+response.getCmd());
        }

    }

    /**
     * Catch exceptions
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
        System.out.println("exceptionCaught");
        super.exceptionCaught(ctx, e);
    }

    /**
     * New connection established
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelConnected");
        super.channelConnected(ctx, e);
    }

    /**
     * Fired only when an established connection is closed
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelDisconnected");
        super.channelDisconnected(ctx, e);
    }

    /**
     * Fired when the channel is closed
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelClosed");
        super.channelClosed(ctx, e);
    }
}

RequestDecoder.java
package com.ithzk.coder;

import com.ithzk.constans.Constants;
import com.ithzk.model.Request;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.FrameDecoder;

/**
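 * Request decoder
 * Packet format (defined to suit our needs):
 * header | module | cmd | length | data
 * header: 4 bytes (int)
 * module: 2 bytes (short)
 * cmd: 2 bytes (short)
 * length: 4 bytes (int, byte length of the data section)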
 * @author hzk
 * @date 2018/9/29
 */
public class RequestDecoder extends FrameDecoder{

    @Override
    protected Object decode(ChannelHandlerContext channelHandlerContext, Channel channel, ChannelBuffer channelBuffer) throws Exception {
        //Readable bytes must be at least the fixed header length
        if(channelBuffer.readableBytes() >= Constants.AbstractDataStructure.DATA_STRUCTURE_LENGTH){
            //Guard against socket byte-stream attacks: discard the buffered bytes and wait for new data
            if(channelBuffer.readableBytes() > 2048){
                channelBuffer.skipBytes(channelBuffer.readableBytes());
                return null;
            }

            //Offset where the packet header begins
            int beginIndex;

            while (true){
                beginIndex = channelBuffer.readerIndex();
                //Mark the reader index
                channelBuffer.markReaderIndex();
                int packHead = channelBuffer.readInt();
                if(Constants.AbstractDataStructure.PACKAGE_HEAD == packHead){
                    break;
                }

                //Not a header: restore the reader index and skip one byte
                channelBuffer.resetReaderIndex();
                channelBuffer.readByte();

                if(channelBuffer.readableBytes() < Constants.AbstractDataStructure.DATA_STRUCTURE_LENGTH){
                    //Incomplete packet: wait for more data
                    return null;
                }
            }
            //Module number
            short module = channelBuffer.readShort();
            //Command number
            short cmd = channelBuffer.readShort();
            //Data length
            int length = channelBuffer.readInt();

            //Check whether the data section of the request is complete
            if(channelBuffer.readableBytes() < length){
                //Restore the reader index to the start of the header
                channelBuffer.readerIndex(beginIndex);
                return null;
            }

            //Read the data section
            byte[] data = new byte[length];
            channelBuffer.readBytes(data);

            Request request = new Request(module,cmd,data);

            //Pass it downstream
            return request;
        }
        //Incomplete packet: wait for more data
        return null;
    }
}
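The while loop in RequestDecoder resynchronizes the stream when it does not start with a header. It tries to read the 4-byte magic value; if that fails, it drops one byte and tries again.

Below is a stdlib sketch of just that scan. HeadResyncDemo is a hypothetical name. It uses an absolute getInt so the position moves by only one byte per failed attempt.

```java
import java.nio.ByteBuffer;

class HeadResyncDemo {
    static final int PACKAGE_HEAD = -37593513;
    static final int HEADER_LENGTH = 12;

    // Scans for the magic header one byte at a time, as RequestDecoder's while loop does.
    // On success the position is just after the header (at the module field).
    // Returns false when fewer than HEADER_LENGTH bytes remain, i.e. "wait for more data".
    static boolean seekHead(ByteBuffer buf) {
        while (buf.remaining() >= HEADER_LENGTH) {
            if (buf.getInt(buf.position()) == PACKAGE_HEAD) {
                buf.position(buf.position() + 4); // consume the header
                return true;
            }
            buf.get(); // not a header: drop one byte and retry
        }
        return false;
    }

    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(3 + HEADER_LENGTH);
        buf.put((byte) 7).put((byte) 8).put((byte) 9); // 3 bytes of garbage
        buf.putInt(PACKAGE_HEAD).putShort((short) 1).putShort((short) 2).putInt(0);
        buf.flip();
        if (seekHead(buf)) {
            System.out.println("module=" + buf.getShort() + " cmd=" + buf.getShort());
        }
    }
}
```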

RequestEncoder.java
package com.ithzk.coder;

import com.ithzk.constans.Constants;
import com.ithzk.model.Request;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.oneone.OneToOneEncoder;

/**
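 * Request encoder
 * Packet format (defined to suit our needs):
 * header | module | cmd | length | data
 * header: 4 bytes (int)
 * module: 2 bytes (short)
 * cmd: 2 bytes (short)
 * length: 4 bytes (int, byte length of the data section)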
 * @author hzk
 * @date 2018/9/29
 */
public class RequestEncoder extends OneToOneEncoder{

    @Override
    protected Object encode(ChannelHandlerContext channelHandlerContext, Channel channel, Object rs) throws Exception {
        Request request = (Request) rs;
        ChannelBuffer channelBuffer = ChannelBuffers.dynamicBuffer();
        //Header
        channelBuffer.writeInt(Constants.AbstractDataStructure.PACKAGE_HEAD);
        //Module
        channelBuffer.writeShort(request.getModule());
        //Command number (cmd)
        channelBuffer.writeShort(request.getCmd());
        //Data length
        channelBuffer.writeInt(request.getDataLength());
        //Data
        if(null != request.getData()){
            channelBuffer.writeBytes(request.getData());
        }
        return channelBuffer;
    }
}

ResponseDecoder.java
package com.ithzk.coder;

import com.ithzk.constans.Constants;
import com.ithzk.model.Response;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.frame.FrameDecoder;

/**
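 * Response decoder
 * Packet format (defined to suit our needs):
 * header | module | cmd | status code | length | data
 * header: 4 bytes (int)
 * module: 2 bytes (short)
 * cmd: 2 bytes (short)
 * status code: 4 bytes (int)
 * length: 4 bytes (int, byte length of the data section)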
 * @author hzk
 * @date 2018/9/29
 */
public class ResponseDecoder extends FrameDecoder{

    @Override
    protected Object decode(ChannelHandlerContext channelHandlerContext, Channel channel, ChannelBuffer channelBuffer) throws Exception {
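        //The response header carries an extra 4-byte status code: 4 + 2 + 2 + 4 + 4 = 16 bytes
        int headLength = Constants.AbstractDataStructure.DATA_STRUCTURE_LENGTH + 4;
        //Readable bytes must be at least the fixed header length
        if(channelBuffer.readableBytes() >= headLength){
            //Guard against socket byte-stream attacks: discard the buffered bytes and wait for new data
            if(channelBuffer.readableBytes() > 2048){
                channelBuffer.skipBytes(channelBuffer.readableBytes());
                return null;
            }

            //Offset where the packet header begins
            int beginIndex;

            while (true){
                beginIndex = channelBuffer.readerIndex();
                //Mark the reader index
                channelBuffer.markReaderIndex();
                if(Constants.AbstractDataStructure.PACKAGE_HEAD == channelBuffer.readInt()){
                    break;
                }

                //Not a header: restore the reader index and skip one byte
                channelBuffer.resetReaderIndex();
                channelBuffer.readByte();

                if(channelBuffer.readableBytes() < headLength){
                    //Incomplete packet: wait for more data
                    return null;
                }
            }

            //Module number
            short module = channelBuffer.readShort();
            //Command number
            short cmd = channelBuffer.readShort();
            //Status code
            int code = channelBuffer.readInt();
            //Data length
            int length = channelBuffer.readInt();

            //Check whether the data section of the response is complete
            if(channelBuffer.readableBytes() < length){
                //Restore the reader index to the start of the header
                channelBuffer.readerIndex(beginIndex);
                return null;
            }

            //Read the data section
            byte[] data = new byte[length];
            channelBuffer.readBytes(data);

            Response response = new Response(module,cmd,data,code);

            //Pass it downstream
            return response;
        }

        //Incomplete packet: wait for more data
        return null;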
    }

}

ResponseEncoder.java
package com.ithzk.coder;

import com.ithzk.constans.Constants;
import com.ithzk.model.Response;
import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;
import org.jboss.netty.channel.Channel;
import org.jboss.netty.channel.ChannelHandlerContext;
import org.jboss.netty.handler.codec.oneone.OneToOneEncoder;

/**
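 * Response encoder
 * Packet format (defined to suit our needs):
 * header | module | cmd | status code | length | data
 * header: 4 bytes (int)
 * module: 2 bytes (short)
 * cmd: 2 bytes (short)
 * status code: 4 bytes (int)
 * length: 4 bytes (int, byte length of the data section)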
 * @author hzk
 * @date 2018/9/29
 */
public class ResponseEncoder extends OneToOneEncoder{

    @Override
    protected Object encode(ChannelHandlerContext channelHandlerContext, Channel channel, Object rs) throws Exception {
        Response response = (Response) rs;
        ChannelBuffer channelBuffer = ChannelBuffers.dynamicBuffer();
        //Header
        channelBuffer.writeInt(Constants.AbstractDataStructure.PACKAGE_HEAD);
        //Module
        channelBuffer.writeShort(response.getModule());
        //Command number (cmd)
        channelBuffer.writeShort(response.getCmd());
        //Status code
        channelBuffer.writeInt(response.getCode());
        //Data length
        channelBuffer.writeInt(response.getDataLength());
        //Data
        if(null != response.getData()){
            channelBuffer.writeBytes(response.getData());
        }
        return channelBuffer;
    }
}

Constants.java
package com.ithzk.constans;

/**
 * Constants
 * @author hzk
 * @date 2018/9/29
 */
public class Constants {

    /**
     * Netty configuration
     */
    public abstract static class AbstractNettyConfig{
        /**
         * Port
         */
        public static final int PORT = 8888;
        /**
         * IP
         */
        public static final String ADDRESS = "127.0.0.1";
    }

    /**
     * Custom packet structure
     */
    public abstract static class AbstractDataStructure{
        /**
         * Packet header (magic number)
         */
        public static final int PACKAGE_HEAD = -37593513;
        /**
         * Fixed header length of a request packet
         * 4 + 2 + 2 + 4
         */
        public static final int DATA_STRUCTURE_LENGTH = 12;
    }

    /**
     * Response status codes
     */
    public abstract static class AbstractStateCode{
        /**
         * Failure
         */
        public static final int FAILURE = 0;

        /**
         * Success
         */
        public static final int SUCCESS = 1;
    }

    /**
     * Modules (game levels)
     */
    public abstract static class AbstractModule{
        /**
         * Simulated level 1
         */
        public static final short ONE = 1;
        /**
         * Simulated level 2
         */
        public static final short TWO = 2;
    }

    /**
     * Command numbers
     */
    public abstract static class AbstractCmd{
        /**
         * Command 1
         */
        public static final short ONE = 1;
        /**
         * Command 2
         */
        public static final short TWO = 2;

    }
}

BufferFactory.java
package com.ithzk.core;


import org.jboss.netty.buffer.ChannelBuffer;
import org.jboss.netty.buffer.ChannelBuffers;

import java.nio.ByteOrder;

/**
 * ChannelBuffers helper
 * @author hzk
 * @date 2018/9/26
 */
public class BufferFactory{

    public static ByteOrder BYTE_ORDER = ByteOrder.BIG_ENDIAN;

    /**
     * Get a new dynamic ChannelBuffer
     * @return
     */
    public static ChannelBuffer getBuffer(){
        ChannelBuffer channelBuffer = ChannelBuffers.dynamicBuffer();
        return channelBuffer;
    }

    /**
     * Get a ChannelBuffer containing a copy of the given bytes
     * @param bytes
     * @return
     */
    public static ChannelBuffer getBuffer(byte[] bytes){
        ChannelBuffer channelBuffer = ChannelBuffers.copiedBuffer(bytes);
        return channelBuffer;
    }


}

Serializable.java
package com.ithzk.core;

import org.jboss.netty.buffer.ChannelBuffer;

import java.nio.charset.Charset;
import java.util.*;

/**
 * Custom serialization
 * @author hzk
 * @date 2018/9/26
 */
public abstract class Serializable {

    public static final Charset CHARSET = Charset.forName("UTF-8");

    protected ChannelBuffer writeBuffer;

    protected ChannelBuffer readBuffer;

    /**
     * Deserialization logic, implemented by subclasses
     */
    protected abstract void read();

    /**
     * Serialization logic, implemented by subclasses
     */
    protected abstract void write();

    /**
     * Read data from a byte array
     * @param bytes
     * @return
     */
    public Serializable readFromBytes(byte[] bytes){
        readBuffer = BufferFactory.getBuffer(bytes);
        read();
        readBuffer.clear();
        return this;
    }

    /**
     * Read data from a ChannelBuffer
     * @param channelBuffer
     */
    public void readFromBuffer(ChannelBuffer channelBuffer){
        this.readBuffer = channelBuffer;
        read();
    }

    /**
     * Write to a new local ChannelBuffer
     * @return
     */
    public ChannelBuffer writeToLocalBuffer(){
        this.writeBuffer = BufferFactory.getBuffer();
        write();
        return writeBuffer;
    }

    /**
     * Write to the given target ChannelBuffer
     * @param channelBuffer
     * @return
     */
    public ChannelBuffer writeToTargetBuffer(ChannelBuffer channelBuffer){
        this.writeBuffer = channelBuffer;
        write();
        return writeBuffer;
    }

    public Serializable writeByte(Byte value){
        writeBuffer.writeByte(value);
        return this;
    }

    public Serializable writeInt(int value){
        writeBuffer.writeInt(value);
        return this;
    }

    public Serializable writeShort(short value){
        writeBuffer.writeShort(value);
        return this;
    }

    public Serializable writeLong(long value){
        writeBuffer.writeLong(value);
        return this;
    }

    public Serializable writeFloat(float value){
        writeBuffer.writeFloat(value);
        return this;
    }

    public Serializable writeDouble(double value){
        writeBuffer.writeDouble(value);
        return this;
    }

    public Serializable writeString(String value){
        if(null == value || value.isEmpty()){
            writeShort((short)0);
            return this;
        }
        byte[] bytes = value.getBytes(CHARSET);
        short size = (short) bytes.length;
        writeBuffer.writeShort(size);
        writeBuffer.writeBytes(bytes);
        return this;
    }

    public Serializable writeObject(Object object){
        if(null == object){
            writeByte((byte)0);
        }else{
            if(object instanceof Integer){
                writeInt((int)object);
            }else if(object instanceof Short){
                writeShort((short)object);
            }else if(object instanceof Byte){
                writeByte((byte)object);
            }else if(object instanceof Long){
                writeLong((long)object);
            }else if(object instanceof Float){
                writeFloat((float)object);
            }else if(object instanceof Double){
                writeDouble((double)object);
            }else if(object instanceof String){
                writeString((String) object);
            }else if(object instanceof Serializable){
                writeByte((byte)1);
                Serializable serializable = (Serializable) object;
                serializable.writeToTargetBuffer(writeBuffer);
            }else{
                throw new RuntimeException(String.format("Unserializable type: [%s]",object.getClass()));
            }
        }
        return this;
    }

    public <T> Serializable writeList(List<T> list){
        if(isEmpty(list)){
            writeBuffer.writeShort((short)0);
            return this;
        }
        writeBuffer.writeShort((short)list.size());
        for(T t:list){
            writeObject(t);
        }
        return this;
    }

    public <K,V> Serializable writeMap(Map<K,V> map){
        if(isEmpty(map)){
            writeBuffer.writeShort((short)0);
            return this;
        }
        writeBuffer.writeShort((short)map.size());
        for (Map.Entry<K,V> entry:map.entrySet()) {
            writeObject(entry.getKey());
            writeObject(entry.getValue());
        }
        return this;
    }


    /**
     * Serialize this object and return the bytes
     * @return
     */
    public byte[] getBytes(){
        writeToLocalBuffer();
        byte[] bytes = null;
        if(writeBuffer.writerIndex() == 0){
            bytes = new byte[0];
        }else{
            bytes = new byte[writeBuffer.writerIndex()];
            writeBuffer.readBytes(bytes);
        }
        writeBuffer.clear();
        return bytes;
    }

    public byte readByte(){
        return readBuffer.readByte();
    }

    public short readShort(){
        return readBuffer.readShort();
    }

    public int readInt(){
        return readBuffer.readInt();
    }

    public long readLong(){
        return readBuffer.readLong();
    }

    public float readFloat(){
        return readBuffer.readFloat();
    }

    public double readDouble(){
        return readBuffer.readDouble();
    }

    public String readString(){
        short size = readBuffer.readShort();
        if(size <= 0){
            return "";
        }
        byte[] bytes = new byte[size];
        readBuffer.readBytes(bytes);
        return new String(bytes,CHARSET);
    }

    public <K> K readObject(Class<K> clz){
        Object k = null;
        if(clz == int.class || clz == Integer.class){
            k = readInt();
        }else if(clz == byte.class || clz == Byte.class){
            k = readByte();
        }else if(clz == short.class || clz == Short.class){
            k = readShort();
        }else if(clz == long.class || clz == Long.class){
            k = readLong();
        }else if(clz == float.class || clz == Float.class){
            k = readFloat();
        }else if(clz == double.class || clz == Double.class){
            k = readDouble();
        }else if(clz == String.class){
            k = readString();
        }else if(Serializable.class.isAssignableFrom(clz)){
            try {
                byte hasObject = readBuffer.readByte();
                if(hasObject == 1){
                    Serializable temp = (Serializable) clz.newInstance();
                    temp.readFromBuffer(readBuffer);
                    k = temp;
                }else{
                    k = null;
                }
            }catch (Exception e){
                e.printStackTrace();
            }
        }else{
            throw new RuntimeException(String.format("Unsupported type: [%s]",clz));
        }
        return (K)k;
    }

    public <T> List<T> readList(Class<T> clz){
        ArrayList<T> list = new ArrayList<>();
        short size = readBuffer.readShort();
        for(int i=0;i<size;i++){
            list.add(readObject(clz));
        }
        return list;
    }

    public <K,V> Map<K,V> readMap(Class<K> keyClz,Class<V> valueClz){
        HashMap<K, V> map = new HashMap<>();
        short size = readBuffer.readShort();
        for (int i =0;i<size;i++){
            K key = readObject(keyClz);
            V value = readObject(valueClz);
            map.put(key,value);
        }
        return map;
    }

    private <T> boolean isEmpty(Collection<T> c) {
        return c == null || c.isEmpty();
    }
    public <K,V> boolean isEmpty(Map<K,V> c) {
        return c == null || c.isEmpty();
    }
}
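Serializable encodes a string as a short byte count followed by UTF-8 bytes. It encodes a list as a short element count followed by the elements.

The sketch below reproduces those two layouts with java.nio.ByteBuffer so they can be checked in isolation. The class name is hypothetical.

The explicit Short.MAX_VALUE check is an addition. The original writeString does not make it, so its (short) cast overflows for strings longer than 32767 bytes.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class StringListCodecDemo {
    // Same layout as writeString: short byte count, then UTF-8 bytes (count 0 for null/empty).
    static void writeString(ByteBuffer buf, String value) {
        if (value == null || value.isEmpty()) {
            buf.putShort((short) 0);
            return;
        }
        byte[] bytes = value.getBytes(StandardCharsets.UTF_8);
        if (bytes.length > Short.MAX_VALUE) {
            throw new IllegalArgumentException("string too long for a short length prefix");
        }
        buf.putShort((short) bytes.length);
        buf.put(bytes);
    }

    // Same layout as readString: a count of 0 or less yields "".
    static String readString(ByteBuffer buf) {
        short size = buf.getShort();
        if (size <= 0) {
            return "";
        }
        byte[] bytes = new byte[size];
        buf.get(bytes);
        return new String(bytes, StandardCharsets.UTF_8);
    }

    // Same layout as writeList for String elements: short element count, then each element.
    static void writeStringList(ByteBuffer buf, List<String> list) {
        buf.putShort((short) list.size());
        for (String s : list) {
            writeString(buf, s);
        }
    }

    static List<String> readStringList(ByteBuffer buf) {
        short size = buf.getShort();
        List<String> out = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            out.add(readString(buf));
        }
        return out;
    }

    public static void main(String[] args) {
        ByteBuffer buf = ByteBuffer.allocate(256);
        writeStringList(buf, Arrays.asList("caf\u00e9", "coffee", ""));
        buf.flip();
        System.out.println(readStringList(buf));
    }
}
```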

Request.java
package com.ithzk.model;

/**
 * Request object
 * @author hzk
 * @date 2018/9/29
 */
public class Request {

    /**
     * Request module
     */
    private short module;

    /**
     * Request command number
     */
    private short cmd;

    /**
     * Data payload
     */
    private byte[] data;

    public Request(short module, short cmd, byte[] data) {
        this.module = module;
        this.cmd = cmd;
        this.data = data;
    }

    public short getModule() {
        return module;
    }

    public void setModule(short module) {
        this.module = module;
    }

    public short getCmd() {
        return cmd;
    }

    public void setCmd(short cmd) {
        this.cmd = cmd;
    }

    public byte[] getData() {
        return data;
    }

    public void setData(byte[] data) {
        this.data = data;
    }

    public int getDataLength(){
        if(null == data){
            return 0;
        }
        return data.length;
    }

}
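A Request is just module, cmd and a byte array; what protects it against sticky and split packets is writing the data length before the data. The project's RequestEncoder is not part of this section, so the sketch below assumes a simple layout of module (short), cmd (short), data length (int), then the data bytes. The real encoder may add further fields such as a header marker; the class name is hypothetical.

```java
import java.nio.ByteBuffer;

// Hypothetical frame layout (an assumption, not taken from the project's RequestEncoder):
// | module: short | cmd: short | length: int | data: length bytes |
public class RequestFrameSketch {

    static final int HEADER_LENGTH = 2 + 2 + 4; // module + cmd + length

    static byte[] encode(short module, short cmd, byte[] data) {
        int dataLength = data == null ? 0 : data.length; // mirrors Request.getDataLength()
        ByteBuffer buf = ByteBuffer.allocate(HEADER_LENGTH + dataLength);
        buf.putShort(module);
        buf.putShort(cmd);
        buf.putInt(dataLength);
        if (dataLength > 0) {
            buf.put(data);
        }
        return buf.array();
    }

    public static void main(String[] args) {
        // A FightRequest body with fightId = 1 and count = 1 is two ints (8 bytes).
        byte[] body = ByteBuffer.allocate(8).putInt(1).putInt(1).array();
        byte[] frame = encode((short) 1, (short) 1, body);
        System.out.println(frame.length); // 16
    }
}
```

With a fixed-size header, the receiver can always read the header first and then knows exactly how many body bytes belong to this frame.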

Response.java
package com.ithzk.model;

/**
 * Response object
 * @author hzk
 * @date 2018/9/29
 */
public class Response {

    /**
     * Module of the originating request
     */
    private short module;

    /**
     * Command number of the originating request
     */
    private short cmd;

    /**
     * Data payload
     */
    private byte[] data;

    /**
     * Status code
     */
    private int code;

    public Response(short module, short cmd, byte[] data, int code) {
        this.module = module;
        this.cmd = cmd;
        this.data = data;
        this.code = code;
    }

    public short getModule() {
        return module;
    }

    public void setModule(short module) {
        this.module = module;
    }

    public short getCmd() {
        return cmd;
    }

    public void setCmd(short cmd) {
        this.cmd = cmd;
    }

    public byte[] getData() {
        return data;
    }

    public void setData(byte[] data) {
        this.data = data;
    }

    public int getCode() {
        return code;
    }

    public void setCode(int code) {
        this.code = code;
    }

    public int getDataLength(){
        if(null == data){
            return 0;
        }
        return data.length;
    }
}
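On the receiving side, a single read may deliver half a frame (split packet) or several frames back to back (sticky packet). Netty 3's FrameDecoder handles this by letting decode() return null until enough bytes are buffered. The sketch below imitates that contract with a plain ByteBuffer for a Response frame. The layout of module (short), cmd (short), code (int), length (int), data is an assumption, since the project's ResponseDecoder is not shown here, and all names are hypothetical.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Hypothetical response frame layout (an assumption, not the project's ResponseDecoder):
// | module: short | cmd: short | code: int | length: int | data: length bytes |
public class ResponseFrameDecoderSketch {

    static final int HEADER_LENGTH = 2 + 2 + 4 + 4;

    // Minimal decoded frame mirroring the fields of Response.
    static final class Frame {
        final short module;
        final short cmd;
        final int code;
        final byte[] data;

        Frame(short module, short cmd, int code, byte[] data) {
            this.module = module;
            this.cmd = cmd;
            this.code = code;
            this.data = data;
        }
    }

    // Returns one frame, or null if buf does not yet hold a complete frame.
    // On null the read position is left at the start of the frame for a later retry.
    static Frame decode(ByteBuffer buf) {
        if (buf.remaining() < HEADER_LENGTH) {
            return null; // header itself is incomplete (split packet)
        }
        buf.mark();
        short module = buf.getShort();
        short cmd = buf.getShort();
        int code = buf.getInt();
        int length = buf.getInt();
        if (length < 0) {
            throw new IllegalStateException("negative length: " + length);
        }
        if (buf.remaining() < length) {
            buf.reset(); // body incomplete: rewind to the start of this frame
            return null;
        }
        byte[] data = new byte[length];
        buf.get(data);
        return new Frame(module, cmd, code, data);
    }

    // Drains every complete frame currently in buf (handles sticky packets).
    static List<Frame> decodeAll(ByteBuffer buf) {
        List<Frame> frames = new ArrayList<>();
        Frame f;
        while ((f = decode(buf)) != null) {
            frames.add(f);
        }
        return frames;
    }

    static byte[] encode(short module, short cmd, int code, byte[] data) {
        ByteBuffer buf = ByteBuffer.allocate(HEADER_LENGTH + data.length);
        buf.putShort(module).putShort(cmd).putInt(code).putInt(data.length).put(data);
        return buf.array();
    }

    public static void main(String[] args) {
        byte[] a = encode((short) 1, (short) 1, 0, new byte[] {1, 2, 3});
        byte[] b = encode((short) 1, (short) 2, 0, new byte[] {4});

        // Two frames glued together, followed by only the first 5 bytes of a third.
        ByteBuffer buf = ByteBuffer.allocate(64);
        buf.put(a).put(b).put(a, 0, 5);
        buf.flip();

        List<Frame> frames = decodeAll(buf);
        System.out.println(frames.size());   // 2
        System.out.println(buf.remaining()); // 5 (partial frame kept for the next read)
    }
}
```

A production decoder would also reject implausibly large lengths before waiting for them; Netty 3 additionally ships LengthFieldBasedFrameDecoder, which covers this length-prefixed pattern without custom code.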

FightRequest.java
package com.ithzk.module.customspass.request;

import com.ithzk.core.Serializable;

/**
 * @author hzk
 * @date 2018/10/8
 */
public class FightRequest extends Serializable{

    /**
     * Fight ID
     */
    private int fightId;

    /**
     * Number of times
     */
    private int count;

    public int getFightId() {
        return fightId;
    }

    public void setFightId(int fightId) {
        this.fightId = fightId;
    }

    public int getCount() {
        return count;
    }

    public void setCount(int count) {
        this.count = count;
    }

    public FightRequest(int fightId, int count) {
        this.fightId = fightId;
        this.count = count;
    }

    public FightRequest() {
    }

    @Override
    protected void read() {
        fightId = readInt();
        count = readInt();
    }

    @Override
    protected void write() {
        writeInt(fightId);
        writeInt(count);
    }

    @Override
    public String toString() {
        return "FightRequest{" +
                "fightId=" + fightId +
                ", count=" + count +
                '}';
    }
}
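FightRequest's write() and read() must handle fields in the same order, because nothing on the wire names the fields: the order itself is the contract. The sketch below shows that round trip with a hypothetical MiniSerializable base backed by a ByteBuffer. It is not the project's Serializable class, whose buffer handling is outside this section, and its fixed-size getBytes(int) is a simplification.

```java
import java.nio.ByteBuffer;

// Hypothetical stand-in for the project's Serializable base class,
// backed by java.nio.ByteBuffer instead of a Netty ChannelBuffer.
abstract class MiniSerializable {
    protected ByteBuffer buf;

    protected abstract void read();
    protected abstract void write();

    protected int readInt() { return buf.getInt(); }
    protected void writeInt(int v) { buf.putInt(v); }

    // Simplification: caller supplies the exact encoded size.
    byte[] getBytes(int size) {
        buf = ByteBuffer.allocate(size);
        write();
        return buf.array();
    }

    void readFromBytes(byte[] bytes) {
        buf = ByteBuffer.wrap(bytes);
        read();
    }
}

// write() and read() use the same field order: that order is the wire contract.
class FightRequestSketch extends MiniSerializable {
    int fightId;
    int count;

    @Override
    protected void read() {
        fightId = readInt();
        count = readInt();
    }

    @Override
    protected void write() {
        writeInt(fightId);
        writeInt(count);
    }
}

public class FightRequestRoundTrip {
    public static void main(String[] args) {
        FightRequestSketch out = new FightRequestSketch();
        out.fightId = 1;
        out.count = 1;
        byte[] bytes = out.getBytes(8); // two ints = 8 bytes

        FightRequestSketch in = new FightRequestSketch();
        in.readFromBytes(bytes);
        System.out.println(in.fightId + "," + in.count); // 1,1
    }
}
```

Swapping the two readInt() calls in read() would still "work" but silently exchange fightId and count, which is why both methods should be kept side by side.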

FightResponse.java
package com.ithzk.module.customspass.response;

import com.ithzk.core.Serializable;

/**
 * @author hzk
 * @date 2018/10/8
 */
public class FightResponse extends Serializable{

    /**
     * Gold earned
     */
    private double gold;

    public double getGold() {
        return gold;
    }

    public void setGold(double gold) {
        this.gold = gold;
    }

    @Override
    protected void read() {
        gold = readDouble();
    }

    @Override
    protected void write() {
        writeDouble(gold);
    }

    @Override
    public String toString() {
        return "FightResponse{" +
                "gold=" + gold +
                '}';
    }
}


Server.java
package com.ithzk.server;

import com.ithzk.coder.RequestDecoder;
import com.ithzk.coder.ResponseEncoder;
import com.ithzk.constans.Constants;
import org.jboss.netty.bootstrap.ServerBootstrap;
import org.jboss.netty.channel.ChannelPipeline;
import org.jboss.netty.channel.ChannelPipelineFactory;
import org.jboss.netty.channel.Channels;
import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory;

import java.net.InetSocketAddress;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

/**
 * Netty server
 * @author hzk
 * @date 2018/10/8
 */
public class Server {

    public static void main(String[] args){
        // Server bootstrap
        ServerBootstrap serverBootstrap = new ServerBootstrap();

        // boss threads accept connections on the port; worker threads handle reads and writes
        ExecutorService boss = Executors.newCachedThreadPool();
        ExecutorService worker = Executors.newCachedThreadPool();

        // Set the NIO socket channel factory
        serverBootstrap.setFactory(new NioServerSocketChannelFactory(boss,worker));

        // Set the pipeline factory
        serverBootstrap.setPipelineFactory(new ChannelPipelineFactory() {
            @Override
            public ChannelPipeline getPipeline() throws Exception {
                //RequestDecoder -> ServerHandler ->ResponseEncoder
                ChannelPipeline pipeline = Channels.pipeline();
                pipeline.addLast("decoder",new RequestDecoder());
                pipeline.addLast("encoder",new ResponseEncoder());
                pipeline.addLast("serverHandler",new ServerHandler());
                return pipeline;
            }
        });

        serverBootstrap.bind(new InetSocketAddress(Constants.AbstractNettyConfig.PORT));
        System.out.println("Server Start...");
    }
}

ServerHandler.java
package com.ithzk.server;

import com.ithzk.constans.Constants;
import com.ithzk.model.Request;
import com.ithzk.model.Response;
import com.ithzk.module.customspass.request.FightRequest;
import com.ithzk.module.customspass.response.FightResponse;
import org.jboss.netty.channel.*;

/**
 * Handler for received messages
 * @author hzk
 * @date 2018/10/8
 */
public class ServerHandler extends SimpleChannelHandler {

    /**
     * Handle an incoming message
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) throws Exception {
        Request request = (Request) e.getMessage();
        if(Constants.AbstractModule.ONE == request.getModule()){
            if(Constants.AbstractCmd.ONE == request.getCmd()){
                FightRequest fightRequest = new FightRequest();
                fightRequest.readFromBytes(request.getData());

                System.out.println("ServerHandler->messageReceived:"+fightRequest);

                // Write the response back to the client
                FightResponse fightResponse = new FightResponse();
                fightResponse.setGold(88.88D);

                Response response = new Response(Constants.AbstractModule.ONE,Constants.AbstractCmd.ONE,fightResponse.getBytes(),Constants.AbstractStateCode.SUCCESS);
                ctx.getChannel().write(response);
            }else if(Constants.AbstractCmd.TWO == request.getCmd()){
                System.out.println("CMD:"+Constants.AbstractCmd.TWO);
            }
        }else if(Constants.AbstractModule.TWO == request.getModule()){
            System.out.println("Module:"+Constants.AbstractModule.TWO);
        }

    }

    /**
     * Handle exceptions
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, ExceptionEvent e) throws Exception {
        System.out.println("exceptionCaught");
        super.exceptionCaught(ctx, e);
    }

    /**
     * New connection established
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelConnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelConnected");
        super.channelConnected(ctx, e);
    }

    /**
     * Fired only when a channel that was already connected is closed
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelDisconnected(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelDisconnected");
        super.channelDisconnected(ctx, e);
    }

    /**
     * Fired when the channel is closed
     * @param ctx
     * @param e
     * @throws Exception
     */
    @Override
    public void channelClosed(ChannelHandlerContext ctx, ChannelStateEvent e) throws Exception {
        System.out.println("channelClosed");
        super.channelClosed(ctx, e);
    }
}
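ServerHandler routes each Request through nested if/else checks on module and cmd. As the number of commands grows, a common alternative (a swapped-in technique, not the author's code) is a lookup table keyed on the (module, cmd) pair. The sketch below is hypothetical: handlers take the request body bytes and return the response body bytes, and the names are illustrative only.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical table-driven alternative to nested if/else on module and cmd.
// Each handler takes the request body bytes and returns the response body bytes.
public class CommandDispatcherSketch {

    private final Map<Integer, Function<byte[], byte[]>> handlers = new HashMap<>();

    // Packs two shorts into one int key: module in the high 16 bits, cmd in the low 16.
    private static int key(short module, short cmd) {
        return (module << 16) | (cmd & 0xFFFF);
    }

    void register(short module, short cmd, Function<byte[], byte[]> handler) {
        if (handlers.putIfAbsent(key(module, cmd), handler) != null) {
            throw new IllegalStateException("duplicate handler for " + module + "/" + cmd);
        }
    }

    // Returns the handler's response body, or throws if nothing is registered.
    byte[] dispatch(short module, short cmd, byte[] body) {
        Function<byte[], byte[]> handler = handlers.get(key(module, cmd));
        if (handler == null) {
            throw new IllegalArgumentException("no handler for " + module + "/" + cmd);
        }
        return handler.apply(body);
    }

    public static void main(String[] args) {
        CommandDispatcherSketch dispatcher = new CommandDispatcherSketch();
        dispatcher.register((short) 1, (short) 1, body -> new byte[] {42});
        dispatcher.register((short) 1, (short) 2, body -> body);

        System.out.println(dispatcher.dispatch((short) 1, (short) 1, new byte[0])[0]);    // 42
        System.out.println(dispatcher.dispatch((short) 1, (short) 2, new byte[] {7})[0]); // 7
    }
}
```

Registering each command once makes mismatched nesting of module and cmd checks impossible, and unknown commands fail loudly instead of being silently ignored.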

Client output:
Client Start...
channelConnected
Please input:
1
1
Server output:
Server Start...
channelConnected
ServerHandler->messageReceived:FightRequest{fightId=1, count=1}

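With this code we can transmit data correctly using our own custom data structure. To understand the whole flow, focus on the Serializable utility class, the client's ClientHandler and the server's ServerHandler. This is a simple way to implement a custom packet protocol; the full code is available in the git repository: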

Netty_Demo[git]


Reposted from blog.csdn.net/u013985664/article/details/83754802