Netty：HTTP 协议服务实现

实现一个简单的 HTTP 请求及响应过程：

1、Client 向 Server 发送 HTTP 请求。

2、Server 端对 HTTP 请求进行解析。

3、Server 端向 Client 发送 HTTP 响应。

4、Client 对 HTTP 响应进行解析。

package rpc.server;

import java.io.IOException;
import java.net.Inet4Address;
import java.net.ServerSocket;
import java.net.UnknownHostException;
import java.util.HashMap;
import java.util.Map;

import org.apache.commons.collections4.MapUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import rpc.common.Constant;
import rpc.common.annotation.HttpService;
import rpc.server.exception.RpcInitialException;
import rpc.server.handler.RpcHttpHandler;
import rpc.server.registry.ServiceRegistry;

/**
 * Spring-managed HTTP RPC server.
 *
 * <p>During {@link #setApplicationContext} it collects every bean annotated with
 * {@link HttpService} into {@code handlerMap}, keyed by the name of the annotation's interface.
 * {@link #afterPropertiesSet} then starts a Netty HTTP server on a new thread and, if a
 * {@link ServiceRegistry} was supplied, registers {@code serverAddress:port} with it under
 * {@code Constant.ZK_HTTP_PATH}.
 */
public class RpcHttpServer implements ApplicationContextAware, InitializingBean {

	// Was RpcServer.class (copy/paste), which attributed this class's log lines to another logger.
	private static final Logger LOGGER = LoggerFactory.getLogger(RpcHttpServer.class);

	/** First port tried by getPort(). */
	private static final int PORT_SCAN_START = 9000;
	/** getPort() only tries ports strictly below this bound. */
	private static final int PORT_SCAN_END = 10000;
	/** Returned by getPort() when no port in [PORT_SCAN_START, PORT_SCAN_END) could be bound. */
	private static final int FALLBACK_PORT = 7999;

	private String serverAddress;
	private ServiceRegistry serviceRegistry;
	private int port;

	// Service beans keyed by service interface name. NOTE(review): static, so it is shared by every
	// RpcHttpServer instance in the JVM; it is a plain HashMap filled during context startup.
	private static final Map<String, Object> handlerMap = new HashMap<String, Object>();

	/**
	 * Creates a server on the local host address and the first free port found by getPort(),
	 * without a service registry.
	 */
	public RpcHttpServer() {
		this.serverAddress = getHost();
		this.port = getPort();
		LOGGER.debug("RpcHttpServer....server address = http://"+this.serverAddress + ":" + this.port);
	}

	/**
	 * Creates a server on the local host address.
	 *
	 * @param serviceRegistry registry the server address is published to; may be null
	 * @param port            port to listen on
	 */
	public RpcHttpServer(ServiceRegistry serviceRegistry, int port) {
		this.serverAddress = getHost();
		this.port = port;
		this.serviceRegistry = serviceRegistry;
		LOGGER.debug("RpcHttpServer....server address = http://"+this.serverAddress + ":" + this.port);
	}

	/**
	 * Creates a server with an explicit advertised address.
	 *
	 * @param serviceRegistry registry the server address is published to; may be null
	 * @param serverAddress   address published to the registry
	 * @param port            port to listen on
	 */
	public RpcHttpServer(ServiceRegistry serviceRegistry, String serverAddress, int port) {
		this.serverAddress = serverAddress;
		this.port = port;
		this.serviceRegistry = serviceRegistry;
		LOGGER.debug("RpcHttpServer....server address = http://"+this.serverAddress + ":" + this.port);
	}

	/**
	 * Returns the textual IP address of the local host.
	 *
	 * @throws RpcInitialException if the local host name cannot be resolved
	 */
	private static String getHost() {
		try {
			// getLocalHost() may return an IPv6 address; the original cast to Inet4Address then failed
			// with ClassCastException, so the address is used as returned.
			return Inet4Address.getLocalHost().getHostAddress();
		} catch (UnknownHostException e) {
			LOGGER.error("Rpc 服务启动失败!", e);
			throw new RpcInitialException("Rpc 服务启动失败!", e);
		}
	}

	/**
	 * Finds a port that can currently be bound by probing [PORT_SCAN_START, PORT_SCAN_END) in order.
	 * NOTE: the probe socket is closed again, so another process may take the port before the
	 * server actually binds it.
	 *
	 * @return the first bindable port, or FALLBACK_PORT if none was found
	 */
	private static int getPort() {
		for (int candidate = PORT_SCAN_START; candidate < PORT_SCAN_END; candidate++) {
			try {
				// The original always probed 9000 here instead of the loop variable.
				ServerSocket probe = new ServerSocket(candidate);
				probe.close();
				return candidate;
			} catch (IOException e) {
				// port in use or not bindable: try the next one
			}
		}
		return FALLBACK_PORT;
	}

	/**
	 * Registers every {@link HttpService}-annotated bean of the context in handlerMap, keyed by
	 * the fully-qualified name of the annotation's value() interface.
	 */
	@Override
	public void setApplicationContext(ApplicationContext ctx) throws BeansException {
		LOGGER.info("RpcServer.setApplicationContext() -- to set ApplicationContext for RPC SERVER");
		// get the rpc services from the spring context
		Map<String, Object> serviceBeanMap = ctx.getBeansWithAnnotation(HttpService.class);
		if (MapUtils.isNotEmpty(serviceBeanMap)) {
			for (Object serviceBean : serviceBeanMap.values()) {
				// NOTE(review): getAnnotation on the runtime class presumably returns null for proxied
				// beans (HttpService is not @Inherited), which would throw an NPE here - confirm.
				String interfaceName = serviceBean.getClass().getAnnotation(HttpService.class).value().getName();
				LOGGER.info("interfaceName = " + interfaceName);
				handlerMap.put(interfaceName, serviceBean);
			}
		} else {
			LOGGER.warn("non http service!");
		}
		LOGGER.info("RpcServer.setApplicationContext() -- to set ApplicationContext for RPC SERVER COMPLETED!");
	}

	/** Starts the Netty server on a new thread so Spring startup is not blocked. */
	@Override
	public void afterPropertiesSet() throws Exception {
		LOGGER.info("RpcServer.afterPropertiesSet() -- begin!");
		Runnable runner = new ChannelStartRunner();
		new Thread(runner).start();
	}

	/**
	 * Binds the HTTP server, registers it with the registry (if any) and blocks until the server
	 * channel is closed, then shuts the event loops down.
	 */
	private class ChannelStartRunner implements Runnable {

		@Override
		public void run() {
			EventLoopGroup bossGroup = new NioEventLoopGroup();
			EventLoopGroup workerGroup = new NioEventLoopGroup();
			try {
				ServerBootstrap bootstrap = new ServerBootstrap();
				bootstrap.group(bossGroup, workerGroup).channel(NioServerSocketChannel.class)
						.childHandler(new ChannelInitializer<SocketChannel>() {
							@Override
							public void initChannel(SocketChannel channel) throws Exception {
								// RpcHttpHandler keeps per-request state, so each channel gets a new one.
								channel.pipeline().addLast(new HttpRequestDecoder())
										.addLast(new HttpResponseEncoder())
										.addLast(new RpcHttpHandler(handlerMap));
							}
						})
						.option(ChannelOption.SO_BACKLOG, 128)
						.childOption(ChannelOption.SO_KEEPALIVE, true);

				ChannelFuture future = bootstrap.bind(port).sync();
				LOGGER.info("RpcServer.afterPropertiesSet() -- service has bind for port:{}!", port);

				if (serviceRegistry != null) {
					LOGGER.info("RpcServer.afterPropertiesSet() -- to register rpc service: {}:{}! ", serverAddress, port);
					serviceRegistry.register(serverAddress + ":" + port, handlerMap.keySet(), Constant.ZK_HTTP_PATH);
				}

				future.channel().closeFuture().sync();
			} catch (Exception e) {
				LOGGER.error("ChannelStartRunner.run() -- ", e);
			} finally {
				workerGroup.shutdownGracefully();
				bossGroup.shutdownGracefully();
			}
		}
	}
}
package rpc.server.handler;

import java.util.Map;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpHeaders.Values;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import net.sf.cglib.reflect.FastClass;
import net.sf.cglib.reflect.FastMethod;
import rpc.common.RpcRequest;
import rpc.common.RpcResponse;
import rpc.common.SerializationUtil;
import rpc.common.http.ByteBufToBytes;

/**
 * Server-side Netty handler that reads an HTTP request body, deserializes it into an
 * {@link RpcRequest}, invokes the matching registered service bean, and writes the serialized
 * {@link RpcResponse} back as the HTTP response body.
 *
 * <p>The request is expected as one {@link HttpRequest} followed by {@link HttpContent} chunks
 * (as produced by Netty's {@code HttpRequestDecoder}); the body is collected using the request's
 * {@code Content-Length}. Per-request state is kept in fields, so every channel needs its own
 * instance. The connection is closed after each response.
 */
public class RpcHttpHandler extends ChannelInboundHandlerAdapter {

	private static final Logger LOGGER = LoggerFactory.getLogger(RpcHttpHandler.class);

	/** Collects the current request body; null if it had no Content-Length or is already complete. */
	private ByteBufToBytes reader;

	/** Registered service beans keyed by service interface name. */
	private final Map<String, Object> handlerMap;

	public RpcHttpHandler(Map<String, Object> handlerMap) {
		this.handlerMap = handlerMap;
	}

	/**
	 * Accumulates the request body and, once it is complete, processes the RPC call and writes
	 * the response.
	 */
	@Override
	public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception {
		if (msg instanceof HttpRequest) {
			HttpRequest head = (HttpRequest) msg;
			// Reset per request so a request without Content-Length never reuses an earlier reader.
			reader = HttpHeaders.isContentLengthSet(head)
					? new ByteBufToBytes((int) HttpHeaders.getContentLength(head))
					: null;
		}

		if (msg instanceof HttpContent) {
			ByteBuf content = ((HttpContent) msg).content();
			try {
				if (reader == null) {
					// Without Content-Length the body size is unknown; the original threw an NPE here.
					LOGGER.warn("RpcHttpHandler.channelRead request without Content-Length, closing connection");
					ctx.close();
					return;
				}
				reader.reading(content);
			} finally {
				// Release even if reading() fails so the chunk is not leaked.
				content.release();
			}

			if (reader.isEnd()) {
				byte[] body = reader.readFull();
				reader = null;
				RpcRequest rpcRequest = SerializationUtil.deserialize(body, RpcRequest.class);
				writeResponse(ctx, process(rpcRequest));
			}
		}
	}

	/** Invokes the service for one request; any failure is stored in the response instead of thrown. */
	private RpcResponse process(RpcRequest rpcRequest) {
		RpcResponse response = new RpcResponse();
		try {
			LOGGER.info("RpcHttpHandler.channelRead0 deal with id = {}", rpcRequest.getRequestId());
			response.setRequestId(rpcRequest.getRequestId());
			Object result = handle(rpcRequest);
			response.setResult(result);
			LOGGER.info("RpcHttpHandler.channelRead0 ended with id ={}", rpcRequest.getRequestId());
		} catch (Throwable t) {
			LOGGER.error("RpcHttpHandler.channelRead0 failed of id = {}!", rpcRequest.getRequestId(), t);
			response.setError(t);
		}
		return response;
	}

	/** Writes the serialized response as a 200 OK HTTP response and closes the connection afterwards. */
	private void writeResponse(ChannelHandlerContext ctx, RpcResponse response) {
		FullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK,
				Unpooled.wrappedBuffer(SerializationUtil.serialize(response)));
		// NOTE(review): the body is a binary serialization, not text; kept as-is for compatibility.
		httpResponse.headers().set(HttpHeaders.Names.CONTENT_TYPE, "text/plain");
		httpResponse.headers().set(HttpHeaders.Names.CONTENT_LENGTH, httpResponse.content().readableBytes());
		// The connection is always closed after writing (see the listener below), so advertise that.
		// The original advertised keep-alive and then closed anyway.
		httpResponse.headers().set(HttpHeaders.Names.CONNECTION, Values.CLOSE);
		ctx.writeAndFlush(httpResponse).addListener(ChannelFutureListener.CLOSE);
	}

	@Override
	public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
		ctx.flush();
	}

	/**
	 * Looks up the service bean registered for the request's class name and invokes the requested
	 * method on it through a CGLIB FastClass.
	 *
	 * @param request the decoded RPC request
	 * @return the method's return value
	 * @throws IllegalArgumentException if no bean is registered under the requested class name
	 * @throws Throwable whatever the lookup or the invoked method throws
	 */
	private Object handle(RpcRequest request) throws Throwable {
		LOGGER.info("RpcHandler.handle deal with id =" + request.getRequestId());
		String className = request.getClassName();
		Object serviceBean = handlerMap.get(className);
		if (serviceBean == null) {
			// The original failed with a message-less NPE on serviceBean.getClass().
			throw new IllegalArgumentException("No HTTP service registered for " + className);
		}

		Class<?> serviceClass = serviceBean.getClass();
		String methodName = request.getMethodName();
		Class<?>[] parameterTypes = request.getParameterTypes();
		Object[] parameters = request.getParameters();

		FastClass serviceFastClass = FastClass.create(serviceClass);
		FastMethod serviceFastMethod = serviceFastClass.getMethod(methodName, parameterTypes);
		LOGGER.info("RpcHandler.handle ened with id =" + request.getRequestId());
		return serviceFastMethod.invoke(serviceBean, parameters);
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
		LOGGER.error("server caught exception", cause);
		ctx.close();
	}

}
package rpc.common.annotation;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.springframework.stereotype.Component;

/**
 * Marks a Spring bean as the implementation of a service exposed over HTTP RPC.
 *
 * <p>The annotation is meta-annotated with {@link Component}, so annotated classes are picked up
 * by {@code <context:component-scan>} and become available when the Spring context starts.
 *
 * @author mrh
 */
@Component
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface HttpService {

	/** The service interface implemented by the annotated bean. */
	Class<?> value();
}
package rpc.common.http;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

/**
 * Accumulates an HTTP body of a known length (normally the {@code Content-Length} value) from
 * successive {@link ByteBuf} chunks and hands it out as a {@code byte[]}.
 *
 * <p>Not thread-safe; intended to be confined to a single channel handler.
 */
public class ByteBufToBytes {
	/** Buffer whose capacity is the expected body length; null once readFull() has consumed it. */
	private ByteBuf temp;

	/** True once the buffer has no writable bytes left, i.e. the whole body has arrived. */
	private boolean end;

	/**
	 * @param length expected body length in bytes
	 */
	public ByteBufToBytes(int length) {
		temp = Unpooled.buffer(length);
		// Only an empty body is complete before any data arrives. The original started with
		// end = true, so isEnd() reported completion before anything had been read.
		end = length == 0;
	}

	/**
	 * Copies all readable bytes of {@code datas} into the body buffer. The caller still owns
	 * {@code datas} and is responsible for releasing it.
	 *
	 * @throws IllegalStateException if the body was already consumed by readFull()
	 */
	public void reading(ByteBuf datas) {
		ensureNotConsumed();
		datas.readBytes(temp, datas.readableBytes());
		end = temp.writableBytes() == 0;
	}

	/** Returns true once the expected number of bytes has been received. */
	public boolean isEnd() {
		return end;
	}

	/**
	 * Returns the complete body and releases the internal buffer.
	 *
	 * @return the body bytes, or null if the body is not complete yet
	 * @throws IllegalStateException if the body was already consumed by an earlier call
	 */
	public byte[] readFull() {
		if (!end) {
			return null;
		}
		ensureNotConsumed();
		byte[] contentByte = new byte[temp.readableBytes()];
		temp.readBytes(contentByte);
		temp.release();
		// Drop the released buffer so later use fails with a clear message instead of a refcount error.
		temp = null;
		return contentByte;
	}

	/** Copies all readable bytes of {@code datas} into a new array (independent of this accumulator). */
	public byte[] read(ByteBuf datas) {
		byte[] bytes = new byte[datas.readableBytes()];
		datas.readBytes(bytes);
		return bytes;
	}

	/** Throws if readFull() has already released the body buffer. */
	private void ensureNotConsumed() {
		if (temp == null) {
			throw new IllegalStateException("body already consumed by readFull()");
		}
	}
}
package rpc.client;

import rpc.common.RpcRequest;
import rpc.common.RpcResponse;

/**
 * Transport used by the RPC proxy to deliver a request to a remote server.
 */
public interface Client {

	/**
	 * Sends an RPC request to the remote server and returns its reply.
	 *
	 * @param request the request to deliver
	 * @return the server's response; an implementation may return null when no complete response
	 *         was received
	 * @throws Exception if sending or receiving fails
	 */
	RpcResponse send(RpcRequest request) throws Exception;

}
package rpc.client;

import java.net.URI;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.AdaptiveRecvByteBufAllocator;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpVersion;
import rpc.common.RpcRequest;
import rpc.common.RpcResponse;
import rpc.common.SerializationUtil;
import rpc.common.http.ByteBufToBytes;


/**
 * Client that sends one RPC request as an HTTP POST and waits for the HTTP response.
 *
 * <p>The instance is also the last handler of its own Netty pipeline: {@link #channelRead}
 * collects the response body (sized by {@code Content-Length}) and deserializes it into an
 * {@link RpcResponse}, which {@link #send} returns after the server has closed the connection.
 * NOTE(review): the handler is not {@code @Sharable}, so Netty presumably refuses to add the same
 * instance to a second pipeline; create a new instance per call.
 *
 * @author mrh
 */
public class HttpClient extends ChannelInboundHandlerAdapter implements Client {

	private static final Logger LOGGER = LoggerFactory.getLogger(HttpClient.class);

	/** Collects the response body; null when no Content-Length was sent or the body is complete. */
	private ByteBufToBytes reader;
	private String host;
	private int port;

	/** Deserialized response of the current call; null until one has been received. */
	private RpcResponse response;

	/** Monitor used to hand the result from the Netty I/O thread to the thread blocked in send(). */
	private final Object obj = new Object();

	/**
	 * Guarded by obj. Set once the current call is over: a response arrived, an exception was
	 * caught, or the connection went inactive. Waiting on this flag in a loop survives spurious
	 * wake-ups and also ends the wait when the call fails without a response.
	 */
	private boolean done;

	public HttpClient(String host, int port) {
		this.host = host;
		this.port = port;
	}

	/** Collects the response body and, once complete, stores the response and wakes send(). */
	@Override
	public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
		LOGGER.debug("channelRead.........");
		if (msg instanceof HttpResponse) {
			HttpResponse httpresponse = (HttpResponse) msg;
			LOGGER.debug("CONTENT_TYPE:" + httpresponse.headers().get(HttpHeaders.Names.CONTENT_TYPE));
			reader = HttpHeaders.isContentLengthSet(httpresponse)
					? new ByteBufToBytes((int) HttpHeaders.getContentLength(httpresponse))
					: null;
		}

		if (msg instanceof HttpContent) {
			ByteBuf content = ((HttpContent) msg).content();
			try {
				if (reader != null) {
					reader.reading(content);
				} else {
					// The original threw an NPE here; send() is released when the connection closes.
					LOGGER.warn("response without Content-Length, content ignored");
				}
			} finally {
				// Release even if reading() fails so the chunk is not leaked.
				content.release();
			}

			if (reader != null && reader.isEnd()) {
				this.response = SerializationUtil.deserialize(reader.readFull(), RpcResponse.class);
				reader = null;
				unLock();
			}
		}
		LOGGER.debug("channelRead..4...end");
	}

	@Override
	public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
		LOGGER.error("4 client caught exception", cause);
		ctx.close();
		unLock();
	}

	@Override
	public void channelInactive(ChannelHandlerContext ctx) throws Exception {
		LOGGER.debug("channelInactive................");
		// Release send() when the server closes without a complete response; the original left
		// the caller blocked in lock() forever in that case.
		unLock();
	}

	@Override
	public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
		LOGGER.debug("channelReadComplete................");
	}

	/**
	 * Sends the RPC request as an HTTP POST and blocks until the server has closed the connection.
	 *
	 * @param request the RPC request, serialized as the HTTP body
	 * @return the deserialized response, or null if the connection ended without a complete response
	 * @throws Exception if connecting, writing, or waiting fails
	 */
	@Override
	public RpcResponse send(RpcRequest request) throws Exception {
		EventLoopGroup group = new NioEventLoopGroup();
		try {
			synchronized (obj) {
				this.response = null;
				this.done = false;
			}
			this.reader = null;
			LOGGER.info("RpcClient...1..send");
			Bootstrap bootstrap = new Bootstrap();
			bootstrap.group(group).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {
				@Override
				public void initChannel(SocketChannel channel) throws Exception {
					channel.pipeline().addLast(new HttpRequestEncoder()) // encodes the outgoing HTTP request
							.addLast(new HttpResponseDecoder()) // decodes the HTTP response
							.addLast(HttpClient.this); // this instance handles the decoded response
				}
			}).option(ChannelOption.SO_KEEPALIVE, true).option(ChannelOption.RCVBUF_ALLOCATOR,
					new AdaptiveRecvByteBufAllocator(64, 131072, 131072));

			ChannelFuture future = bootstrap.connect(host, port).sync();

			// Build the HTTP request; the serialized RPC request is the body.
			URI uri = new URI("http://"+host+":"+port);
			DefaultFullHttpRequest httprequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST,
					uri.toASCIIString(), Unpooled.wrappedBuffer(SerializationUtil.serialize(request)));
			httprequest.headers().set(HttpHeaders.Names.HOST, host);
			httprequest.headers().set(HttpHeaders.Names.CONNECTION, HttpHeaders.Values.KEEP_ALIVE);
			httprequest.headers().set(HttpHeaders.Names.CONTENT_LENGTH, httprequest.content().readableBytes());
			httprequest.headers().set("id", request.getRequestId());

			future.channel().writeAndFlush(httprequest).sync();

			lock();

			future.channel().closeFuture().sync();
			LOGGER.info("RpcClient..4...end");
			return response;
		} finally {
			group.shutdownGracefully();
		}
	}

	/**
	 * Blocks the calling thread until the current call is over (see {@code done}).
	 *
	 * @throws InterruptedException if the waiting thread is interrupted
	 */
	private void lock() throws InterruptedException {
		synchronized (obj) {
			if (!done) {
				LOGGER.info("RpcClient..2...wait");
			}
			// Loop rather than a single check: Object.wait() may wake up spuriously.
			while (!done) {
				obj.wait();
			}
		}
	}

	/** Marks the current call as over and wakes any thread waiting in lock(). */
	private void unLock() {
		synchronized (obj) {
			done = true;
			obj.notifyAll();
		}
	}
}
package rpc.client.proxy;

import java.lang.reflect.Method;
import java.rmi.UnknownHostException;
import java.util.UUID;

import org.apache.log4j.Logger;

import net.sf.cglib.proxy.InvocationHandler;
import net.sf.cglib.proxy.Proxy;
import rpc.client.Client;
import rpc.client.HttpClient;
import rpc.client.RpcClient;
import rpc.client.discover.ZKServiceDiscovery;
import rpc.common.RpcProtocol;
import rpc.common.RpcRequest;
import rpc.common.RpcResponse;

/**
 * Factory for client-side RPC proxies.
 *
 * <p>Every call on a proxy is packed into an {@link RpcRequest} and sent to a server whose
 * address is either fixed at construction or looked up through {@link ZKServiceDiscovery} on each
 * call. An address has the form {@code host:port} or {@code host:port:PROTOCOL}; {@code HTTP}
 * selects {@link HttpClient}, while {@code TCP} or no protocol selects {@link RpcClient}.
 *
 * @author mrh
 */
public class RpcProxy {

	private static final Logger logger = Logger.getLogger(RpcProxy.class);
	/** Fixed server address; only used when no service discovery is configured. */
	private String serverAddress;
	private ZKServiceDiscovery serviceDiscovery;

	/** @param serverAddress fixed {@code host:port[:PROTOCOL]} address of the server */
	public RpcProxy(String serverAddress) {
		this.serverAddress = serverAddress;
	}

	/** @param serviceDiscovery used to look up the server address on every call */
	public RpcProxy(ZKServiceDiscovery serviceDiscovery) {
		this.serviceDiscovery = serviceDiscovery;
	}

	/**
	 * Creates a proxy whose calls are sent as requests for the invoked method's declaring class.
	 *
	 * @param interfaceClass the interface the proxy implements
	 * @return the proxy
	 */
	@SuppressWarnings("unchecked")
	public <T> T create(Class<?> interfaceClass) {
		return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class<?>[] { interfaceClass },
				new InvocationHandler() {
					@Override
					public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
						return invokeRemote(method.getDeclaringClass().getName(), method, args);
					}
				});
	}

	/**
	 * Creates a proxy whose calls are sent as requests for an explicitly named remote service.
	 *
	 * @param interfaceClass the interface the proxy implements
	 * @param serviceName    class name sent in every request (and used for discovery)
	 * @return the proxy
	 */
	@SuppressWarnings("unchecked")
	public <T> T create(Class<?> interfaceClass, final String serviceName) {
		return (T) Proxy.newProxyInstance(interfaceClass.getClassLoader(), new Class<?>[] { interfaceClass },
				new InvocationHandler() {
					@Override
					public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
						return invokeRemote(serviceName, method, args);
					}
				});
	}

	/**
	 * Builds the request, resolves the server, sends the request and unwraps the response.
	 *
	 * @throws UnknownHostException     if no server address is available for {@code className}
	 * @throws IllegalArgumentException if the address is not of the form host:port[:PROTOCOL]
	 * @throws IllegalStateException    if the client returned no response
	 * @throws Throwable                the remote error carried in the response, or a transport error
	 */
	private Object invokeRemote(String className, Method method, Object[] args) throws Throwable {
		logger.info("RpcProxy....to send rpc request!");
		RpcRequest request = new RpcRequest(); // build the RPC request
		request.setRequestId(UUID.randomUUID().toString());
		request.setClassName(className);
		request.setMethodName(method.getName());
		request.setParameterTypes(method.getParameterTypes());
		request.setParameters(args);

		// Resolve into a local variable: the original stored the discovered address in the shared
		// serverAddress field, so concurrent calls could send a request to another service's server.
		String address = serviceDiscovery != null ? serviceDiscovery.discover(className) : serverAddress;
		if (address == null)
			throw new UnknownHostException("没有系统服务,请稍后重试!" + request.getClassName());
		String[] array = address.split(":");
		if (array.length < 2) {
			throw new IllegalArgumentException("Invalid server address (expected host:port[:PROTOCOL]): " + address);
		}
		String host = array[0];
		int port = Integer.parseInt(array[1]);

		Client client = newClient(host, port, array);
		RpcResponse response = client.send(request); // send the RPC request to the server
		logger.info("RpcProxy....rpc ended!");
		if (response == null) {
			// The original dereferenced the null response and threw an NPE.
			throw new IllegalStateException("No response received from " + address + " for " + className);
		}
		if (response.isError()) {
			logger.error("RpcProxy remote server process failed!", response.getError());
			throw response.getError();
		}
		return response.getResult();
	}

	/**
	 * Chooses the transport from the optional third address segment: "HTTP" selects HttpClient;
	 * "TCP" or a missing segment selects RpcClient. Both create() variants use this, whereas the
	 * original second variant always used RpcClient. Presumably RpcProtocol is an enum, so an
	 * unknown name makes valueOf throw IllegalArgumentException (as in the original).
	 */
	private static Client newClient(String host, int port, String[] addressParts) {
		if (addressParts.length > 2 && RpcProtocol.valueOf(addressParts[2]) == RpcProtocol.HTTP) {
			return new HttpClient(host, port);
		}
		return new RpcClient(host, port);
	}

}

猜你喜欢

转载自muruiheng.iteye.com/blog/2332289