基于 socket 手写一个 TCP 服务端及客户端

  通过 socket 实现一个 TCP 服务端与客户端,实现通过 TCP 协议进行消息收发。

  关键在 socket 的使用的理解上。

  socket 是对操作系统提供的协议栈的封装,底层调用的是操作系统提供的协议栈。

  当我们调用 ServerSocket 的 accept 方法时,线程阻塞。以 TCP 协议为例,直到网卡接收到一个三次握手的连接请求,网卡向 CPU 发送中断信号,CPU 调用中断处理程序唤醒我们阻塞在 accept 方法上的线程,进行连接处理。

  三次握手的过程是由协议栈完成的,我们在应用层编程无法感知。直到三次握手完成,协议栈将客户端信息与服务端信息封装在一个 Socket 对象中返回,我们通过该Socket 对象完成数据的收发。

  值得注意的是,在连接建立完成前,操作系统会为本次连接在内核空间开辟两个数据缓冲区:发送缓冲区与接收缓冲区。

  我们要做的是监听发送缓冲区是否有数据到达,以及将需要发送的数据写入发送缓冲区。

  至于网卡接收到的数据何时由操作系统拆包并写入接收缓冲区,以及我们写入发送缓冲区的数据何时会被封装为 TCP 报文发送给网卡是操作系统 OS 控制的,这对我们来说是透明的。不同的 OS 对此会有不同的实现,我们不需要关注这些细节(或者说想要关注也没办法介入)。

  本次实现是通过传统的 ServerSocket 建立服务端,并没有使用通道技术。也就是说是 BIO 的实现,当并发量比较大时可以采用 NIO 多路复用技术进行优化。这可以帮助我们节约线程数。

  对 TCP 报文进行分包有多种方式,本次实现使用的是最普适的方式,通过报文头添加报文长度字段进行分包,也就是与 HTTP 协议 Header 中的 Content-Length 相同的方式。

  测试时客户端发送的数据是一个序列化的对象,服务端对其进行反序列化并检查结果。 

  由于牵扯到线程的切换,本次实现并没有对代码结构进行提前设计,仅仅是简单的实现了数据收发功能。经过设计优化的代码将在下篇博客发出。

  服务端:

/**
 * @Author Nxy
 * @Date 2020/3/21 17:16
 * @Description socket 服务端
 */
public class BasicSeverDemo {
    public static void main(String[] args) {
        ServerSocket server = null;
        try {
            server = new ServerSocket(80);
            System.out.println("server start!");
        } catch (IOException e) {
            e.printStackTrace();
            return;
        }
        while (!Thread.currentThread().isInterrupted()) {
            Socket socket;
            BufferedInputStream in;
            BufferedOutputStream out;
            try {
                //阻塞等待连接请求
                socket = server.accept();
                System.out.println("建立连接:" + socket.getInetAddress());
                in = new BufferedInputStream(socket.getInputStream());
                out = new BufferedOutputStream(socket.getOutputStream());
            } catch (IOException e) {
                e.printStackTrace();
                System.out.println("连接建立失败!");
                continue;
            }
            byte[] result;
            try {
                //阻塞等待接收请求数据
                byte[] lengthByte = IOUtil.readBytesFromInputStream(in, 4);
                //本次请求的长度
                int length = ByteBuffer.wrap(lengthByte).getInt();
                System.out.println("from server:" + length);
                //读取指定长度字节
                result = IOUtil.readBytesFromInputStream(in, length);
            } catch (Exception e) {
                e.printStackTrace();
                break;
            }
            //反序列化对象
            Invocation obj = null;
            try {
                ByteArrayInputStream bis = new ByteArrayInputStream(result);
                ObjectInputStream ois = new ObjectInputStream(bis);
                obj = (Invocation) ois.readObject();
                ois.close();
                bis.close();
            } catch (IOException ex) {
                ex.printStackTrace();
            } catch (ClassNotFoundException ex) {
                ex.printStackTrace();
            }
            System.out.println(obj.getInterfaceName() + ":" + obj.getMethodName());
        }
    }

  客户端:

/**
 * @Author Nxy
 * @Date 2020/3/21 17:54
 * @Description socket 客户端
 */
public class BasicClientDemo {
    public static void main(String[] args) {
        Socket socket;
        BufferedOutputStream out;
        BufferedInputStream in;
        try {
            socket = new Socket("127.0.0.1", 80);
            out = new BufferedOutputStream(socket.getOutputStream());
            in = new BufferedInputStream(socket.getInputStream());
        } catch (IOException e) {
            e.printStackTrace();
            return;
        }
        Object[] params = new Object[2];
        Class[] paramTypes = new Class[2];
        Invocation invocation = new Invocation(BasicClientDemo.class.getName(), "main", paramTypes, params);
        byte[] invocationBytes = toByteArray(invocation);
        int length = invocationBytes.length;
        try {
            System.out.println("from client:" + length);
            out.write(ByteBuffer.allocate(4).putInt(length).array());
            out.flush();
            out.write(invocationBytes);
            out.flush();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            try {
                out.flush();
                out.close();
            } catch (IOException e) {
                e.printStackTrace();
            }
        }
    }

  工具类:

public class IOUtil {
    /**
     * @Author Nxy
     * @Date 2020/3/21 20:21
     * @Param in:输入流,length:读取字节数
     * @Return
     * @Exception
     * @Description 从输入流读取指定长度字节的数据
     */
    public static byte[] readBytesFromInputStream(BufferedInputStream in,
                                                  int length) throws IOException {
        int readSize;
        byte[] bytes = null;
        bytes = new byte[length];
        long length_tmp = length;
        long index = 0;// start from zero
        while ((readSize = in.read(bytes, (int) index, (int) length_tmp)) != -1) {
            length_tmp -= readSize;
            if (length_tmp == 0) {
                break;
            }
            index = index + readSize;
        }
        return bytes;
    }

    public static byte[] toByteArray(Object obj) {
        byte[] bytes = null;
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        try {
            ObjectOutputStream oos = new ObjectOutputStream(bos); oos.writeObject(obj); oos.flush(); bytes = bos.toByteArray(); oos.close(); bos.close(); } catch (IOException ex) { ex.printStackTrace(); } return bytes; }
}

  执行效果,服务端正常接收到数据并成功反序列化为对象:

猜你喜欢

转载自www.cnblogs.com/niuyourou/p/12542209.html