dubbo的rpc功能简化实现

在开源中国里面看到一个入门级RPC框架实现的项目,使用Spring + Netty + Protostuff + ZooKeeper实现的。第一眼看到介绍就感觉这是个简化版本的dubbo,虽然只实现了里面rpc的这一块功能感觉也很值得学习一下。


目录


1.介绍

2.使用方式与实现思路

3.部分源码分析


<!----分割线---->


1.介绍


gitee源码地址

作者博客讲解


rpc过程图:



spring:主要是用来依赖注入,通过对象序列化之后使用其代理调用目标接口

netty:简化nio开发,以数据流转过程中添加编码解码器来实现通信协议的开发,这里用到了nio,同时将序列化框架以编码解码器的方式整合进去

protostuff:序列化框架,类型产品很多,主要是因为jdk自带的序列化备受诟病

zk:用来维护一张服务列表,主要是利用其强数据一致性实现服务的动态上下线,用来实现服务的暴露、注册、发现等


整合起来就是目前rpc常见的一种实现方式了,rpc可以基于应用层协议实现也可以基于传输层协议实现,各有好坏。基于tcp实现rpc效率更高,没有http请求那么多冗余信息,但是对很多问题例如握手连接、断线重连、心跳检测等问题需要自己来开发增加开发难度,而http请求结果这儿长时间的发展已经把各种问题都考虑进去了。


2.使用方式与实现思路


使用方式和实现思路和dubbo类似的。


使用方式

服务提供者:
(a)写服务提供者接口,实现服务提供者接口,将服务提供者配置到sprimg中,同时指定一个发布服务的端口。作者添加了注解的支持。
(b)封装了zk客户端实例到bean中,配置到spring中,同时指定zk服务ip地址用作服务注册。
(c)加载spring配置文件,服务提供者跑起来。


服务消费者:
(a)将服务提供者接口配置到spring中。用户调用他,可以远程调用服务提供者
(b)和上面一样,将zk客户端实例封装后配置到spring中,同时指定zk服务的ip地址,消费者从zk服务上面发现服务提供者(就是获取到服务提供者是否上线,跑在那个端口下)
(c)创建代理接口进行远程调用


实现思路

核心就在自定义一个rpc协议利用netty在消费者端发送消息(消息中包含想要调用的方法、类、参数等),在提供者端接收消息本地调用后生成调用结果返回消息(消息中包含了调用状态、error、结果等)。

在消费者端利用jdk的动态代理生成一个提供者接口的代理对象,调用这个代理对象的目标方法(触发了上面的过程)并返回远程调用的结果。


3.部分源码分析


(a)封装zk客户端用来服务注册
首先手动建立永久节点/registry,注册的服务都在这个节点下面建立临时节点。核心就是ServiceRegistry中的createNode方法。

//常量,zk客户端超时时间,临时节点的前后缀等
public interface Constant {

    int ZK_SESSION_TIMEOUT = 5000;

    String ZK_REGISTRY_PATH = "/registry";
    String ZK_DATA_PATH = ZK_REGISTRY_PATH + "/data";
}

//注册中心实现(注入到spring中)
public class ServiceRegistry {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceRegistry.class);

    private CountDownLatch latch = new CountDownLatch(1);

    private String registryAddress;

    public ServiceRegistry(String registryAddress) {
        this.registryAddress = registryAddress;
    }

    public void register(String data) {
        if (data != null) {
            ZooKeeper zk = connectServer();
            if (zk != null) {
                createNode(zk, data);
            }
        }
    }

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (IOException | InterruptedException e) {
            LOGGER.error("", e);
        }
        return zk;
    }

    private void createNode(ZooKeeper zk, String data) {
        try {
            byte[] bytes = data.getBytes();
            String path = zk.create(Constant.ZK_DATA_PATH, bytes, ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL);
            LOGGER.debug("create zookeeper node ({} => {})", path, data);
        } catch (KeeperException | InterruptedException e) {
            LOGGER.error("", e);
        }
    }
}

(b)封装zk客户端用来服务发现核心就是watchNode方法用来读取zk下面的临时节点,读取到就证明服务在线且可用。

public class ServiceDiscovery {

    private static final Logger LOGGER = LoggerFactory.getLogger(ServiceDiscovery.class);

    private CountDownLatch latch = new CountDownLatch(1);

    private volatile List<String> dataList = new ArrayList<>();

    private String registryAddress;

    public ServiceDiscovery(String registryAddress) {
        this.registryAddress = registryAddress;

        ZooKeeper zk = connectServer();
        if (zk != null) {
            watchNode(zk);
        }
    }

    public String discover() {
        String data = null;
        int size = dataList.size();
        if (size > 0) {
            if (size == 1) {
                data = dataList.get(0);
                LOGGER.debug("using only data: {}", data);
            } else {
                data = dataList.get(ThreadLocalRandom.current().nextInt(size));
                LOGGER.debug("using random data: {}", data);
            }
        }
        return data;
    }

    private ZooKeeper connectServer() {
        ZooKeeper zk = null;
        try {
            zk = new ZooKeeper(registryAddress, Constant.ZK_SESSION_TIMEOUT, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getState() == Event.KeeperState.SyncConnected) {
                        latch.countDown();
                    }
                }
            });
            latch.await();
        } catch (IOException | InterruptedException e) {
            LOGGER.error("", e);
        }
        return zk;
    }

    private void watchNode(final ZooKeeper zk) {
        try {
            List<String> nodeList = zk.getChildren(Constant.ZK_REGISTRY_PATH, new Watcher() {
                @Override
                public void process(WatchedEvent event) {
                    if (event.getType() == Event.EventType.NodeChildrenChanged) {
                        watchNode(zk);
                    }
                }
            });
            List<String> dataList = new ArrayList<>();
            for (String node : nodeList) {
                byte[] bytes = zk.getData(Constant.ZK_REGISTRY_PATH + "/" + node, false, null);
                dataList.add(new String(bytes));
            }
            LOGGER.debug("node data: {}", dataList);
            this.dataList = dataList;
        } catch (KeeperException | InterruptedException e) {
            LOGGER.error("", e);
        }
    }
}

(c)开发netty

服务的注册于发现主要是用来动态得管理服务,进行服务治理。

正正的rpc功能还是要靠netty来实现。作者的思路其实是利用netty自己基于tcp的基础上封装了一个rpc协议(rpcRequest用来传输序列化对象包括想要远程调用的接口类名、方法名、参数等信息,rpcRespnse用来返回调用状态、调用结果等信息)。

然后自定义rpc协议的编码器、解码器进行消息通信,具体如下:首先从zk上面获取注册的服务地址,在消费者端使用netty发送rpcRequest到服务提供者,服务提供者端利用自定义的rpc解码器解析rpcRequest获取到消费者想要调用的某个方法,然后利用反射进行调用,接着将结果构造成rpcResponse发回给消费者。消费者再次利用netty解析rpcResponse获取到调用结果。

整个过程利用的相关技术包括:自定义rpcRequest类、rpcResponse类,两个解码器,zk客户端读取临时节点,解码编码之间使用SimpleChannelInboundHandler来处理rpc请求(真正利用反射invoke()进行调用就是在这个类里面),序列化过程(序列化主要在消费者端进行,反射调用在提供者端进行,这个时候高效率的序列化框架就用上了)


rpcRequest/rpcResponse:

public class RpcRequest {

    private String requestId;
    private String className;
    private String methodName;
    private Class<?>[] parameterTypes;
    private Object[] parameters;

    // getter/setter...
}

public class RpcResponse {

    private String requestId;
    private Throwable error;
    private Object result;

    // getter/setter...
}

rpcRequest/rpcResponse的编解码器:

public class RpcDecoder extends ByteToMessageDecoder {

    private Class<?> genericClass;

    public RpcDecoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (in.readableBytes() < 4) {
            return;
        }
        in.markReaderIndex();
        int dataLength = in.readInt();
        if (dataLength < 0) {
            ctx.close();
        }
        if (in.readableBytes() < dataLength) {
            in.resetReaderIndex();
            return;
        }
        byte[] data = new byte[dataLength];
        in.readBytes(data);

        Object obj = SerializationUtil.deserialize(data, genericClass);
        out.add(obj);
    }
}


public class RpcEncoder extends MessageToByteEncoder {

    private Class<?> genericClass;

    public RpcEncoder(Class<?> genericClass) {
        this.genericClass = genericClass;
    }

    @Override
    public void encode(ChannelHandlerContext ctx, Object in, ByteBuf out) throws Exception {
        if (genericClass.isInstance(in)) {
            byte[] data = SerializationUtil.serialize(in);
            out.writeInt(data.length);
            out.writeBytes(data);
        }
    }
}


序列化工具类:

整合了序列化框架,当然先用jdk原生的也可以,想用其他的如marshling、protobuf也可以,直接在工具类中修改实现就可以了。

public class SerializationUtil {

    private static Map<Class<?>, Schema<?>> cachedSchema = new ConcurrentHashMap<>();

    private static Objenesis objenesis = new ObjenesisStd(true);

    private SerializationUtil() {
    }

    @SuppressWarnings("unchecked")
    private static <T> Schema<T> getSchema(Class<T> cls) {
        Schema<T> schema = (Schema<T>) cachedSchema.get(cls);
        if (schema == null) {
            schema = RuntimeSchema.createFrom(cls);
            if (schema != null) {
                cachedSchema.put(cls, schema);
            }
        }
        return schema;
    }

    @SuppressWarnings("unchecked")
    public static <T> byte[] serialize(T obj) {
        Class<T> cls = (Class<T>) obj.getClass();
        LinkedBuffer buffer = LinkedBuffer.allocate(LinkedBuffer.DEFAULT_BUFFER_SIZE);
        try {
            Schema<T> schema = getSchema(cls);
            return ProtostuffIOUtil.toByteArray(obj, schema, buffer);
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        } finally {
            buffer.clear();
        }
    }

    public static <T> T deserialize(byte[] data, Class<T> cls) {
        try {
            T message = (T) objenesis.newInstance(cls);
            Schema<T> schema = getSchema(cls);
            ProtostuffIOUtil.mergeFrom(data, message, schema);
            return message;
        } catch (Exception e) {
            throw new IllegalStateException(e.getMessage(), e);
        }
    }
}

提供者端处理rpc请求的handler:

直接继承netty的SimpleChannelInboundHandler即可。

public class RpcHandler extends SimpleChannelInboundHandler<RpcRequest> {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcHandler.class);

    private final Map<String, Object> handlerMap;

    public RpcHandler(Map<String, Object> handlerMap) {
        this.handlerMap = handlerMap;
    }

    @Override
    public void channelRead0(final ChannelHandlerContext ctx, RpcRequest request) throws Exception {
        RpcResponse response = new RpcResponse();
        response.setRequestId(request.getRequestId());
        try {
            Object result = handle(request);
            response.setResult(result);
        } catch (Throwable t) {
            response.setError(t);
        }
        ctx.writeAndFlush(response).addListener(ChannelFutureListener.CLOSE);
    }

    private Object handle(RpcRequest request) throws Throwable {
        String className = request.getClassName();
        Object serviceBean = handlerMap.get(className);

        Class<?> serviceClass = serviceBean.getClass();
        String methodName = request.getMethodName();
        Class<?>[] parameterTypes = request.getParameterTypes();
        Object[] parameters = request.getParameters();

        /*Method method = serviceClass.getMethod(methodName, parameterTypes);
        method.setAccessible(true);
        return method.invoke(serviceBean, parameters);*/

        FastClass serviceFastClass = FastClass.create(serviceClass);
        FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);
        return serviceFastMethod.invoke(serviceBean, parameters);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        LOGGER.error("server caught exception", cause);
        ctx.close();
    }
}


消费者端发送请求的handler:

public class RpcClient extends SimpleChannelInboundHandler<RpcResponse> {

    private static final Logger LOGGER = LoggerFactory.getLogger(RpcClient.class);

    private String host;
    private int port;

    private RpcResponse response;

    private final Object obj = new Object();

    public RpcClient(String host, int port) {
        this.host = host;
        this.port = port;
    }

    @Override
    public void channelRead0(ChannelHandlerContext ctx, RpcResponse response) throws Exception {
        this.response = response;

        synchronized (obj) {
            obj.notifyAll(); // 收到响应,唤醒线程
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        LOGGER.error("client caught exception", cause);
        ctx.close();
    }

    public RpcResponse send(RpcRequest request) throws Exception {
        EventLoopGroup group = new NioEventLoopGroup();
        try {
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group).channel(NioSocketChannel.class)
                .handler(new ChannelInitializer<SocketChannel>() {
                    @Override
                    public void initChannel(SocketChannel channel) throws Exception {
                        channel.pipeline()
                            .addLast(new RpcEncoder(RpcRequest.class)) // 将 RPC 请求进行编码(为了发送请求)
                            .addLast(new RpcDecoder(RpcResponse.class)) // 将 RPC 响应进行解码(为了处理响应)
                            .addLast(RpcClient.this); // 使用 RpcClient 发送 RPC 请求
                    }
                })
                .option(ChannelOption.SO_KEEPALIVE, true);

            ChannelFuture future = bootstrap.connect(host, port).sync();
            future.channel().writeAndFlush(request).sync();

            synchronized (obj) {
                obj.wait(); // 未收到响应,使线程等待
            }

            if (response != null) {
                future.channel().closeFuture().sync();
            }
            return response;
        } finally {
            group.shutdownGracefully();
        }
    }
}

(d)代理开发

消费者想要调用提供者需要对象来调用其方法,消费者由于是远程调用所以使用一个代理对象,主要是根据想要调用的接口生成代理对象,核心在于代理对象的InvocationHandler方法中利用上面的netty发送rpc请求,获取rpc想要然后生成调用结果返回。所以使用代理对象调用目标方法时就得到了远程调用后的结果。这样就营造了一种远程方法像在本地调用一样的效果

public class RpcProxy {

    private String serverAddress;
    private ServiceDiscovery serviceDiscovery;

    public RpcProxy(String serverAddress) {
        this.serverAddress = serverAddress;
    }

    public RpcProxy(ServiceDiscovery serviceDiscovery) {
        this.serviceDiscovery = serviceDiscovery;
    }

    @SuppressWarnings("unchecked")
    public <T> T create(Class<?> interfaceClass) {
        return (T) Proxy.newProxyInstance(
            interfaceClass.getClassLoader(),
            new Class<?>[]{interfaceClass},
            new InvocationHandler() {
                @Override
                public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                    RpcRequest request = new RpcRequest(); // 创建并初始化 RPC 请求
                    request.setRequestId(UUID.randomUUID().toString());
                    request.setClassName(method.getDeclaringClass().getName());
                    request.setMethodName(method.getName());
                    request.setParameterTypes(method.getParameterTypes());
                    request.setParameters(args);

                    if (serviceDiscovery != null) {
                        serverAddress = serviceDiscovery.discover(); // 发现服务
                    }

                    String[] array = serverAddress.split(":");
                    String host = array[0];
                    int port = Integer.parseInt(array[1]);

                    RpcClient client = new RpcClient(host, port); // 初始化 RPC 客户端
                    RpcResponse response = client.send(request); // 通过 RPC 客户端发送 RPC 请求并获取 RPC 响应

                    if (response.isError()) {
                        throw response.getError();
                    } else {
                        return response.getResult();
                    }
                }
            }
        );
    }
}




猜你喜欢

转载自blog.csdn.net/qq_34448345/article/details/79360603