本实验的完整代码详见:https://github.com/Zhang-Qing-Yun/network-lab
目的和内容
- 设计并实现一个基本HTTP 代理服务器。要求在指定端口(例如8080)接收来自客户的HTTP 请求并且根据其中的URL 地址访问该地址所指向的HTTP 服务器(原服务器),接收HTTP 服务器的响应报文,并将响应报文转发给对应的客户进行浏览。
- 设计并实现一个支持Cache 功能的HTTP 代理服务器。要求能缓存原服务器响应的对象,并能够通过修改请求报文(添加if-modified-since头行),向原服务器确认缓存对象是否是最新版本。
- 扩展HTTP 代理服务器,支持如下功能:
a) 网站过滤:允许/不允许访问某些网站;
b) 用户过滤:支持/不支持某些用户访问外部网站;
c) 网站引导:将用户对某个网站的访问引导至一个模拟网站(钓鱼)。
原理
HTTP网络应用通信原理
在HTTP网络应用中,通信的两个进程主要采用客户端/服务器模式(或浏览器/服务器模式),客户端向服务器发送请求,服务器接收到客户端请求后,向客户端提供相应的服务。通信过程如下:
服务器端:
- 服务器端需要首先启动,并绑定一个本地主机端口,在端口上提供服务
- 等待客户端请求
- 接收到客户端请求时,建立起与客户端通信的套接字,开启新线程,将与客户端通信的套接字放入新线程处理
- 返回第二步,主线程继续等待客户端请求。
- 关闭服务器
客户端:
- 根据服务器IP与端口,建立起与服务器通信的socket
- 向服务器发送请求报文,并等待服务器应答
- 请求结束后关闭socket
HTTP代理服务器原理
RFC 7230规定,代理在HTTP通信中扮演一个中间人的角色,对于连接来的客户端来说,它扮演一个服务器的角色;对于要连接的远程服务器,它扮演一个客户端的角色。代理服务器就负责在客户端和服务器之间转发报文。如下图所示:
代理服务器在指定端口监听浏览器的请求,在接收到浏览器的请求时,首先查看浏览器的IP地址,如果来自被限制的IP地址,就向客户端返回错误信息。否则,从请求头中解析出请求的host主机,如果属于不允许访问的主机,则向客户端返回错误信息,如果属于需要引导的host,则修改请求头内所有的地址字段为被引导的地址。之后,根据URL查找是否缓存中是否有该URL的缓存,如果存在,则从中取出Last-modified头部内容,并构造包含If-modified-since的请求头,向服务器发送确认最新版本的报文,并在返回的请求头第一行里确认是否有“Not Modified”,如果存在该字段,则说明本地缓存未过期,直接将本地缓存内容发送给客户端,否则缓存过期,将服务器的报文直接写回客户端。如果缓存中不存在,就直接将客户端请求转发到服务器,并将服务器返回内容缓存后,再返回给客户端。
代理服务器的拦截用户、拦截主机和钓鱼信息都预先配置在配置文件里,并在程序运行后读入程序中,以按照规则执行。
程序运行流程图如下:
代码实验
代理服务器启动并监听客户端连接
public void start() throws IOException {
// 启动服务端
ServerSocket serverSocket = new ServerSocket(port);
startLog();
// 遇到客户端连接就创建一个任务,然后提交到线程池当中,接下来由该线程与客户端保持通信
while (true) {
// 监听客户端连接
Socket socket = serverSocket.accept();
System.out.println("接收到了"+ socket.getInetAddress() + " " + socket.getPort() + "的连接");
// 创建一个任务并提交给线程池处理
threadPool.execute(new ProxyTask(socket));
}
}
线程池
static final class MixedTargetThreadPool {
// 首先从环境变量 mixed.thread.amount 中获取预先配置的线程数
// 如果没有对 mixed.thread.amount 做配置,则使用常量 MIXED_MAX 作为线程数
private static final int max = (null != System.getProperty(MIXED_THREAD_AMOUNT)) ?
Integer.parseInt(System.getProperty(MIXED_THREAD_AMOUNT)) : MIXED_MAX;
// 自定义线程池
private static final ThreadPoolExecutor EXECUTOR = new ThreadPoolExecutor(
MIXED_CORE,
max,
KEEP_ALIVE_SECONDS,
TimeUnit.SECONDS,
new LinkedBlockingQueue(QUEUE_SIZE),
new ThreadPoolExecutor.CallerRunsPolicy()
);
static {
EXECUTOR.allowCoreThreadTimeOut(true);
// 钩子函数,用来关闭线程池
Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
@Override
public void run() {
shutdownThreadPoolGracefully(EXECUTOR);
}
}));
}
}
基于LRU的缓存设计
public class LRUCache implements Cache {
/**
* 默认的代理服务器要缓存的请求的最大容量
*/
private static final int MAX_CAPACITY = 1024;
// 读写锁
private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
// 读锁
private final WriteLock writeLock = lock.writeLock();
// 写锁
private final ReadLock readLock = lock.readLock();
private final LRU<String, byte[]> lru;
public LRUCache() {
lru = new LRU<>(MAX_CAPACITY);
}
public LRUCache(int maxCapacity) {
lru = new LRU<>(maxCapacity);
}
@Override
public void addCache(String url, byte[] content) {
writeLock.lock();
try {
lru.put(url, content);
} finally {
writeLock.unlock();
}
}
@Override
public byte[] getContent(String url) {
readLock.lock();
try {
return lru.get(url);
} finally {
readLock.unlock();
}
}
/**
* 具体的LRU数据结构 <br/>
* 不是线程安全的,需要自行解决线程安全问题
*/
static class LRU<K, V> extends LinkedHashMap<K, V> {
// 最大的容量
private final int maxCapacity;
public LRU(int maxCapacity) {
// accessOrder参数为true时,当调用get和put方法时会将访问到的元素放到双向链表的尾部
super(16, 0.75f, true);
this.maxCapacity = maxCapacity;
}
// 实现LRU的关键方法,如果map里面的元素个数大于了缓存最大容量,则返回true,然后会删除链表的顶端元素eldest
@Override
public boolean removeEldestEntry(Map.Entry<K, V> eldest){
return size() > maxCapacity;
}
}
}
加载配置文件
protected void initConfig() {
InputStream inputStream = null;
BufferedReader urlReader = null;
BufferedReader userReader = null;
BufferedReader fishingReader = null;
try {
// 读取主配置文件proxy.properties
inputStream = this.getClass().getClassLoader().getResourceAsStream("proxy.properties");
Properties properties = new Properties();
properties.load(inputStream);
// 加载主配置
SingletonFactory factory = SingletonFactory.getInstance();
ProxyConfig config = factory.getObject(ProxyConfig.class);
config.setUrlRule(Integer.parseInt(properties.getProperty("urlRule")));
config.setUserRule(Integer.parseInt(properties.getProperty("userRule")));
// 设置配置的url,文件里一行就是一个url
InputStream urlStream = this.getClass().getClassLoader().getResourceAsStream("url.txt");
if (urlStream != null) {
urlReader = new BufferedReader(new InputStreamReader(urlStream));
List<String> urls = new ArrayList<>();
String line;
while ((line = urlReader.readLine()) != null) {
urls.add(line);
}
config.setUrls(urls);
}
// 设置配置的User即主机地址,一行就是一个地址
InputStream userStream = this.getClass().getClassLoader().getResourceAsStream("user.txt");
if (userStream != null) {
userReader = new BufferedReader(new InputStreamReader(userStream));
List<String> users = new ArrayList<>();
String line;
while ((line = userReader.readLine()) != null) {
users.add(line);
}
config.setUsers(users);
}
// 设置要被钓鱼的用户,一行就是就是一个用户即主机地址
InputStream fishingStream = this.getClass().getClassLoader().getResourceAsStream("fishing.txt");
if (fishingStream != null) {
fishingReader = new BufferedReader(new InputStreamReader(fishingStream));
List<String> fishingUsers = new ArrayList<>();
String line;
while ((line = fishingReader.readLine()) != null) {
fishingUsers.add(line);
}
config.setFishingUsers(fishingUsers);
}
} catch (IOException e) {
e.printStackTrace();
System.out.println("配置文件不存在或格式不正确");
} finally {
// 关闭资源
if (inputStream != null) {
try {
inputStream.close();
} catch (IOException e) {
e.printStackTrace();
}
}
if (urlReader != null) {
try {
urlReader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
if (userReader != null) {
try {
userReader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
if (fishingReader != null) {
try {
fishingReader.close();
} catch (IOException e) {
e.printStackTrace();
}
}
}
}
提交给线程池的任务(主要的业务逻辑)
package com.qingyun.network.task;
import com.google.common.primitives.Bytes;
import com.qingyun.network.cache.LRUCache;
import com.qingyun.network.config.ProxyConfig;
import com.qingyun.network.constants.ProxyConstants;
import com.qingyun.network.factory.SingletonFactory;
import com.qingyun.network.util.IOUtil;
import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
/**
* @description: 具体执行代理业务的任务,目前只能做到一次请求一个TCP连接即HTTP1.0的情景
* @author: 張青云
* @create: 2021-10-27 20:07
**/
public class ProxyTask implements Runnable {
// 用于和客户端通信的TCP套接字
private final Socket socket;
// 用户和目的服务器建立连接的TCP套接字
private Socket targetSocket;
// 缓存
private final LRUCache cache;
// 配置信息
private final ProxyConfig config;
public ProxyTask(Socket socket) {
this.socket = socket;
cache = SingletonFactory.getInstance().getObject(LRUCache.class);
config = SingletonFactory.getInstance().getObject(ProxyConfig.class);
}
@Override
public void run() {
InputStream clientInputStream;
OutputStream clientOutputStream;
String url = null;
String host = null;
int port = 80;
StringBuffer buffer = new StringBuffer(); // HTTP请求头的字符形式
try {
clientInputStream = socket.getInputStream();
clientOutputStream = socket.getOutputStream();
// 解析HTTP请求头
String line;
while ((line = IOUtil.readHttpLine(clientInputStream)) != null) {
if (line.startsWith("GET")) {
// GET /index.html HTTP1.1
url = line.split(" ")[1];
} else if (line.startsWith("Host")) {
// Host: 127.0.0.1:80
host = line.split(" ")[1];
}
buffer.append(line).append("\r\n");
}
buffer.append("\r\n");
if (host == null) {
// TODO:处理没有带host字段的请求
return;
}
// 解析地址和端口号,如果没有端口号则使用默认的80
String[] split = host.split(":");
host = split[0];
if (split.length != 1) {
port = Integer.parseInt(split[1]);
}
// 网站过滤
if (config.getUrlRule() == ProxyConstants.ALLOW_URL) {
// 如果当前访问的网站没在配置文件中则拦截
if (!config.getUrls().contains(host)) {
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
} else if (config.getUrlRule() == ProxyConstants.REFUSE_URL) {
// 如果要访问的网站存在于配置文件中则拦截
if (config.getUrls().contains(host)) {
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
} else {
// 配置文件写错了
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
// 用户过滤
String clientHost = socket.getInetAddress().getHostAddress();
if (config.getUserRule() == ProxyConstants.ALLOW_USER) {
// 如果客户端的Host不在配置文件里拦截
if (config.getUsers().contains(clientHost)) {
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
} else if (config.getUserRule() == ProxyConstants.REFUSE_USER) {
// 如果客户端的Host在配置文件里拦截
if (config.getUsers().contains(clientHost)) {
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
} else {
// 配置文件写错了
clientOutputStream.write(refuseProxy().getBytes());
clientOutputStream.flush();
return;
}
// 钓鱼
if (config.getFishingUsers().contains(clientHost)) {
// 构造发送给钓鱼网站的HTTP报文
StringBuffer fishingHTTP = new StringBuffer();
fishingHTTP.append("GET " + ProxyConstants.fishingUrl + " HTTP/1.1" + "\r\n");
fishingHTTP.append("Host: " + ProxyConstants.fishingHost + "\r\n");
fishingHTTP.append("\r\n");
String fishingHTTPStr = fishingHTTP.toString();
// 建立连接然后发送数据
String[] hostAndPort = ProxyConstants.fishingHost.split(":");
targetSocket = new Socket(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
OutputStream outputStream = targetSocket.getOutputStream();
outputStream.write(fishingHTTPStr.getBytes());
waitTargetServerAndTransfer(clientOutputStream, targetSocket.getInputStream());
return;
}
// 对于非GET请求的方法,直接转发给目的服务器
if (url == null) {
transfer(host, port, buffer, clientInputStream, clientOutputStream);
return;
}
String uri = url;
byte[] content = cache.getContent(uri);
// 对于GET方法,如果缓存中存在则向目的服务器发送条件GET
if (content != null) {
// 从缓存中提取Last-Modified值
String lastModified = parseLastModified(content);
// 构造条件GET请求
StringBuffer ifGetReqBuffer = new StringBuffer();
ifGetReqBuffer.append("GET " + url + " HTTP/1.1\r\n");
ifGetReqBuffer.append("Host: " + host + ":" + port + "\r\n");
ifGetReqBuffer.append("If-modified-since: " + lastModified + "\r\n");
ifGetReqBuffer.append("\r\n");
String ifGetReq = ifGetReqBuffer.toString();
// 向目的服务器发送
targetSocket = new Socket(host, port);
OutputStream outputStream = targetSocket.getOutputStream();
InputStream inputStream = targetSocket.getInputStream();
outputStream.write(ifGetReq.getBytes());
outputStream.flush();
// 阻塞式监听目的服务器的返回值
String respFirstLine = IOUtil.readHttpLine(inputStream);
int code = Integer.parseInt(respFirstLine.split(" ")[1]);
// 缓存过期
if (code != 304) {
System.out.println("代理服务器对" + uri + "的缓存过期");
// 将报文转发至客户端
byte[] firstLine = (respFirstLine + "\r\n").getBytes();
clientOutputStream.write(firstLine);
byte[] resp = waitTargetServerAndTransfer(clientOutputStream, inputStream);
// 将响应结果进行缓存,只缓存具有Last-Modified首部行的响应结果
if (new String(resp).contains("Last-Modified")) {
cache.addCache(uri, ArrayUtils.addAll(firstLine, resp));
System.out.println("对" + uri + "的响应结果进行了缓存");
}
} else {
// 缓存命中
System.out.println("对" + uri + "的访问命中缓存");
// 直接返回缓存中的值
clientOutputStream.write(content);
clientOutputStream.flush();
}
} else {
// 缓存不存在,则直接请求目的服务器,然后转发给客户端,并在代理服务器进行缓存
System.out.println("代理服务器没有对" + uri + "请求的缓存");
byte[] resp = transfer(host, port, buffer, null, clientOutputStream);
// 将响应结果进行缓存,只缓存具有Last-Modified首部行的响应结果
if (new String(resp).contains("Last-Modified")) {
cache.addCache(uri, resp);
System.out.println("对" + uri + "的响应结果进行了缓存");
}
}
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
if (socket != null) {
socket.close();
}
if (targetSocket != null) {
targetSocket.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* 在客户端和目标服务器之间进行转发,也就是将客户端的内容直接发送到目的服务器,然后再将目的服务器返回的内容直接转发给客户端
* @param host 目标服务器主机地址
* @param port 目标服务器端口号
* @param head HTTP的请求头
* @param body HTTP除去head后的内容的输入流
* @param clientOutputStream 客户端socket的输出流
* @return 目标服务器返回的相应内容
*/
private byte[] transfer(String host, int port, StringBuffer head, InputStream body, OutputStream clientOutputStream) throws IOException {
// 和远程服务器建立连接
// TODO:有BUG,可能连不上目标服务器
targetSocket = new Socket(host, port);
InputStream targetServerInputStream = targetSocket.getInputStream();
OutputStream targetServerOutputStream = targetSocket.getOutputStream();
// 先写入请求头
targetServerOutputStream.write(head.toString().getBytes());
// 请求体不为null时写入请求体
if (body != null) {
byte[] bytes = new byte[256 * 1024];
int size;
// TODO:有BUG,可能读不到完整数据;但是如果while循环读的话,如果目标服务器不关闭TCP连接,则会阻塞在这里
if ((size = body.read(bytes)) >= 0) {
targetServerOutputStream.write(bytes, 0, size);
}
}
targetServerOutputStream.flush();
// 同步阻塞式等待目标服务器返回响应
return waitTargetServerAndTransfer(clientOutputStream, targetServerInputStream);
}
/**
* 同步阻塞式等待目标服务器返回响应,并且将响应结果直接返回给客户端
* @param clientOutputStream 到客户端的输出流
* @param targetServerInputStream 到目的服务器的输入流
* @return 客户端的响应结果
*/
private byte[] waitTargetServerAndTransfer(OutputStream clientOutputStream,
InputStream targetServerInputStream) throws IOException {
List<byte[]> response = new ArrayList<>();
byte[] bytes = new byte[256 * 1024];
int length;
// TODO:有BUG,可能读不到完整数据;但是如果while循环读的话,如果目标服务器不关闭TCP连接,则会阻塞在这里
if ((length = targetServerInputStream.read(bytes)) >= 0) {
// 写回给客户端
clientOutputStream.write(bytes, 0, length);
// 收集响应结果
byte[] part = new byte[length];
System.arraycopy(bytes, 0, part, 0, length);
response.add(part);
}
// 将响应结果返回
List<Byte> list = new LinkedList<>();
for (byte[] one: response) {
list.addAll(Bytes.asList(one));
}
return Bytes.toArray(list);
}
/**
* 从缓存中提取Last-Modified值
* @param context HTTP报文
* @return Last-Modified值,如果没有则返回null
*/
private String parseLastModified(byte[] context) {
StringBuffer headLine = new StringBuffer();
for (int i = 0; i < context.length; i++) {
if (context[i] == '\r') {
// 请求头解析结束时都没有找到Last-Modified请求行
if (headLine.length() == 0) {
return null;
}
String s = headLine.toString();
if (s.startsWith("Last-Modified")) {
return s.substring(15);
}
i++;
headLine = new StringBuffer();
continue;
}
headLine.append((char) context[i]);
}
return null;
}
/**
* 拒绝代理时向客户端返回的HTTP报文
*/
private String refuseProxy() {
String resp = "HTTP/1.1 500 Internal Server Error\r\n";
resp += "\r\n";
return resp;
}
}
这里只是列出了重要代码,如需查看完整代码,或者想要参考我的编程风格,请移步到该实验的代码仓库去查看,相信你一定会有收获的。(代码仓库:https://github.com/Zhang-Qing-Yun/network-lab)