实现一个HTTP代理服务器(哈工大计网实验一Java版)

本实验的完整代码详见:https://github.com/Zhang-Qing-Yun/network-lab

目的和内容

  1. 设计并实现一个基本HTTP 代理服务器。要求在指定端口(例如8080)接收来自客户的HTTP 请求并且根据其中的URL 地址访问该地址所指向的HTTP 服务器(原服务器),接收HTTP 服务器的响应报文,并将响应报文转发给对应的客户进行浏览。
  2. 设计并实现一个支持Cache 功能的HTTP 代理服务器。要求能缓存原服务器响应的对象,并能够通过修改请求报文(添加if-modified-since头行),向原服务器确认缓存对象是否是最新版本。
  3. 扩展HTTP 代理服务器,支持如下功能:
    a) 网站过滤:允许/不允许访问某些网站;
    b) 用户过滤:支持/不支持某些用户访问外部网站;
    c) 网站引导:将用户对某个网站的访问引导至一个模拟网站(钓鱼)。

原理

HTTP网络应用通信原理

在HTTP网络应用中,通信的两个进程主要采用客户端/服务器模式(或浏览器/服务器模式),客户端向服务器发送请求,服务器接收到客户端请求后,向客户端提供相应的服务。通信过程如下:
1

服务器端:

  1. 服务器端需要首先启动,并绑定一个本地主机端口,在端口上提供服务
  2. 等待客户端请求
  3. 接收到客户端请求时,建立起与客户端通信的套接字,开启新线程,将与客户端通信的套接字放入新线程处理
  4. 返回第二步,主线程继续等待客户端请求。
  5. 关闭服务器

客户端:

  1. 根据服务器IP与端口,建立起与服务器通信的socket
  2. 向服务器发送请求报文,并等待服务器应答
  3. 请求结束后关闭socket

HTTP代理服务器原理

RFC 7230规定,代理在HTTP通信中扮演一个中间人的角色,对于连接来的客户端来说,它扮演一个服务器的角色;对于要连接的远程服务器,它扮演一个客户端的角色。代理服务器就负责在客户端和服务器之间转发报文。如下图所示:
在这里插入图片描述
代理服务器在指定端口监听浏览器的请求,在接收到浏览器的请求时,首先查看浏览器的IP地址,如果来自被限制的IP地址,就向客户端返回错误信息。否则,从请求头中解析出请求的host主机,如果属于不允许访问的主机,则向客户端返回错误信息,如果属于需要引导的host,则修改请求头内所有的地址字段为被引导的地址。之后,根据URL查找是否缓存中是否有该URL的缓存,如果存在,则从中取出Last-modified头部内容,并构造包含If-modified-since的请求头,向服务器发送确认最新版本的报文,并在返回的请求头第一行里确认是否有“Not Modified”,如果存在该字段,则说明本地缓存未过期,直接将本地缓存内容发送给客户端,否则缓存过期,将服务器的报文直接写回客户端。如果缓存中不存在,就直接将客户端请求转发到服务器,并将服务器返回内容缓存后,再返回给客户端。

代理服务器的拦截用户、拦截主机和钓鱼信息都预先配置在配置文件里,并在程序运行后读入程序中,以按照规则执行。

程序运行流程图如下:
1

代码实验

代理服务器启动并监听客户端连接

    public void start() throws IOException {
    
    
        //  启动服务端
        ServerSocket serverSocket = new ServerSocket(port);
        startLog();
        //  遇到客户端连接就创建一个任务,然后提交到线程池当中,接下来由该线程与客户端保持通信
        while (true) {
    
    
            //  监听客户端连接
            Socket socket = serverSocket.accept();
            System.out.println("接收到了"+ socket.getInetAddress() + " " + socket.getPort() + "的连接");
            //  创建一个任务并提交给线程池处理
            threadPool.execute(new ProxyTask(socket));
        }
    }

线程池

static final class MixedTargetThreadPool {
    
    
    //  首先从环境变量 mixed.thread.amount 中获取预先配置的线程数
    //  如果没有对 mixed.thread.amount 做配置,则使用常量 MIXED_MAX 作为线程数
    private static final int max = (null != System.getProperty(MIXED_THREAD_AMOUNT)) ?
            Integer.parseInt(System.getProperty(MIXED_THREAD_AMOUNT)) : MIXED_MAX;
    //  自定义线程池
    private static final ThreadPoolExecutor EXECUTOR = new ThreadPoolExecutor(
            MIXED_CORE,
            max,
            KEEP_ALIVE_SECONDS,
            TimeUnit.SECONDS,
            new LinkedBlockingQueue(QUEUE_SIZE),
            new ThreadPoolExecutor.CallerRunsPolicy()
    );
    static {
    
    
        EXECUTOR.allowCoreThreadTimeOut(true);
        //  钩子函数,用来关闭线程池
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
    
    
            @Override
            public void run() {
    
    
                shutdownThreadPoolGracefully(EXECUTOR);
            }
        }));
    }
}

基于LRU的缓存设计

public class LRUCache implements Cache {
    
    
    /**
     * 默认的代理服务器要缓存的请求的最大容量
     */
    private static final int MAX_CAPACITY = 1024;

    //  读写锁
    private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

    //  读锁
    private final WriteLock writeLock = lock.writeLock();

    //  写锁
    private final ReadLock readLock = lock.readLock();

    private final LRU<String, byte[]> lru;


    public LRUCache() {
    
    
        lru = new LRU<>(MAX_CAPACITY);
    }

    public LRUCache(int maxCapacity) {
    
    
        lru = new LRU<>(maxCapacity);
    }

    @Override
    public void addCache(String url, byte[] content) {
    
    
        writeLock.lock();
        try {
    
    
            lru.put(url, content);
        } finally {
    
    
            writeLock.unlock();
        }
    }

    @Override
    public byte[] getContent(String url) {
    
    
        readLock.lock();
        try {
    
    
            return lru.get(url);
        } finally {
    
    
            readLock.unlock();
        }
    }


    /**
     * 具体的LRU数据结构 <br/>
     * 不是线程安全的,需要自行解决线程安全问题
     */
    static class LRU<K, V> extends LinkedHashMap<K, V> {
    
    
        //  最大的容量
        private final int maxCapacity;

        public LRU(int maxCapacity) {
    
    
            //  accessOrder参数为true时,当调用get和put方法时会将访问到的元素放到双向链表的尾部
            super(16, 0.75f, true);
            this.maxCapacity = maxCapacity;
        }

        //  实现LRU的关键方法,如果map里面的元素个数大于了缓存最大容量,则返回true,然后会删除链表的顶端元素eldest
        @Override
        public boolean removeEldestEntry(Map.Entry<K, V> eldest){
    
    
            return size() > maxCapacity;
        }
    }
}

加载配置文件

protected void initConfig() {
    
    
    InputStream inputStream = null;
    BufferedReader urlReader = null;
    BufferedReader userReader = null;
    BufferedReader fishingReader = null;

    try {
    
    
        //  读取主配置文件proxy.properties
        inputStream = this.getClass().getClassLoader().getResourceAsStream("proxy.properties");
        Properties properties = new Properties();
        properties.load(inputStream);

        //  加载主配置
        SingletonFactory factory = SingletonFactory.getInstance();
        ProxyConfig config = factory.getObject(ProxyConfig.class);
        config.setUrlRule(Integer.parseInt(properties.getProperty("urlRule")));
        config.setUserRule(Integer.parseInt(properties.getProperty("userRule")));

        //  设置配置的url,文件里一行就是一个url
        InputStream urlStream = this.getClass().getClassLoader().getResourceAsStream("url.txt");
        if (urlStream != null) {
    
    
            urlReader = new BufferedReader(new InputStreamReader(urlStream));
            List<String> urls = new ArrayList<>();
            String line;
            while ((line = urlReader.readLine()) != null) {
    
    
                urls.add(line);
            }
            config.setUrls(urls);
        }

        //  设置配置的User即主机地址,一行就是一个地址
        InputStream userStream = this.getClass().getClassLoader().getResourceAsStream("user.txt");
        if (userStream != null) {
    
    
            userReader = new BufferedReader(new InputStreamReader(userStream));
            List<String> users = new ArrayList<>();
            String line;
            while ((line = userReader.readLine()) != null) {
    
    
                users.add(line);
            }
            config.setUsers(users);
        }

        //  设置要被钓鱼的用户,一行就是就是一个用户即主机地址
        InputStream fishingStream = this.getClass().getClassLoader().getResourceAsStream("fishing.txt");
        if (fishingStream != null) {
    
    
            fishingReader = new BufferedReader(new InputStreamReader(fishingStream));
            List<String> fishingUsers = new ArrayList<>();
            String line;
            while ((line = fishingReader.readLine()) != null) {
    
    
                fishingUsers.add(line);
            }
            config.setFishingUsers(fishingUsers);
        }

    } catch (IOException e) {
    
    
        e.printStackTrace();
        System.out.println("配置文件不存在或格式不正确");
    } finally {
    
    
        //  关闭资源
        if (inputStream != null) {
    
    
            try {
    
    
                inputStream.close();
            } catch (IOException e) {
    
    
                e.printStackTrace();
            }
        }
        if (urlReader != null) {
    
    
            try {
    
    
                urlReader.close();
            } catch (IOException e) {
    
    
                e.printStackTrace();
            }
        }
        if (userReader != null) {
    
    
            try {
    
    
                userReader.close();
            } catch (IOException e) {
    
    
                e.printStackTrace();
            }
        }
        if (fishingReader != null) {
    
    
            try {
    
    
                fishingReader.close();
            } catch (IOException e) {
    
    
                e.printStackTrace();
            }
        }
    }
}

提交给线程池的任务(主要的业务逻辑)

package com.qingyun.network.task;

import com.google.common.primitives.Bytes;
import com.qingyun.network.cache.LRUCache;
import com.qingyun.network.config.ProxyConfig;
import com.qingyun.network.constants.ProxyConstants;
import com.qingyun.network.factory.SingletonFactory;
import com.qingyun.network.util.IOUtil;
import org.apache.commons.lang3.ArrayUtils;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.Socket;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;

/**
 * @description: 具体执行代理业务的任务,目前只能做到一次请求一个TCP连接即HTTP1.0的情景
 * @author: 張青云
 * @create: 2021-10-27 20:07
 **/
public class ProxyTask implements Runnable {
    
    
    //  用于和客户端通信的TCP套接字
    private final Socket socket;

    //  用户和目的服务器建立连接的TCP套接字
    private Socket targetSocket;

    //  缓存
    private final LRUCache cache;

    //  配置信息
    private final ProxyConfig config;

    public ProxyTask(Socket socket) {
    
    
        this.socket = socket;
        cache = SingletonFactory.getInstance().getObject(LRUCache.class);
        config = SingletonFactory.getInstance().getObject(ProxyConfig.class);
    }

    @Override
    public void run() {
    
    
        InputStream clientInputStream;
        OutputStream clientOutputStream;
        String url = null;
        String host = null;
        int port = 80;
        StringBuffer buffer = new StringBuffer();  // HTTP请求头的字符形式

        try {
    
    
            clientInputStream = socket.getInputStream();
            clientOutputStream = socket.getOutputStream();

            //  解析HTTP请求头
            String line;
            while ((line = IOUtil.readHttpLine(clientInputStream)) != null) {
    
    
                if (line.startsWith("GET")) {
    
    
                    //  GET /index.html HTTP1.1
                    url = line.split(" ")[1];
                } else if (line.startsWith("Host")) {
    
    
                    //  Host: 127.0.0.1:80
                    host = line.split(" ")[1];
                }
                buffer.append(line).append("\r\n");
            }
            buffer.append("\r\n");

            if (host == null) {
    
    
                //  TODO:处理没有带host字段的请求
                return;
            }

            //  解析地址和端口号,如果没有端口号则使用默认的80
            String[] split = host.split(":");
            host = split[0];
            if (split.length != 1) {
    
    
                port = Integer.parseInt(split[1]);
            }

            //  网站过滤
            if (config.getUrlRule() == ProxyConstants.ALLOW_URL) {
    
    
                //  如果当前访问的网站没在配置文件中则拦截
                if (!config.getUrls().contains(host)) {
    
    
                    clientOutputStream.write(refuseProxy().getBytes());
                    clientOutputStream.flush();
                    return;
                }
            } else if (config.getUrlRule() == ProxyConstants.REFUSE_URL) {
    
    
                //  如果要访问的网站存在于配置文件中则拦截
                if (config.getUrls().contains(host)) {
    
    
                    clientOutputStream.write(refuseProxy().getBytes());
                    clientOutputStream.flush();
                    return;
                }
            } else {
    
      // 配置文件写错了
                clientOutputStream.write(refuseProxy().getBytes());
                clientOutputStream.flush();
                return;
            }

            //  用户过滤
            String clientHost = socket.getInetAddress().getHostAddress();
            if (config.getUserRule() == ProxyConstants.ALLOW_USER) {
    
    
                //  如果客户端的Host不在配置文件里拦截
                if (config.getUsers().contains(clientHost)) {
    
    
                    clientOutputStream.write(refuseProxy().getBytes());
                    clientOutputStream.flush();
                    return;
                }
            } else if (config.getUserRule() == ProxyConstants.REFUSE_USER) {
    
    
                //  如果客户端的Host在配置文件里拦截
                if (config.getUsers().contains(clientHost)) {
    
    
                    clientOutputStream.write(refuseProxy().getBytes());
                    clientOutputStream.flush();
                    return;
                }
            } else {
    
      // 配置文件写错了
                clientOutputStream.write(refuseProxy().getBytes());
                clientOutputStream.flush();
                return;
            }

            //  钓鱼
            if (config.getFishingUsers().contains(clientHost)) {
    
    
                //  构造发送给钓鱼网站的HTTP报文
                StringBuffer fishingHTTP = new StringBuffer();
                fishingHTTP.append("GET " + ProxyConstants.fishingUrl + " HTTP/1.1" + "\r\n");
                fishingHTTP.append("Host: " + ProxyConstants.fishingHost + "\r\n");
                fishingHTTP.append("\r\n");
                String fishingHTTPStr = fishingHTTP.toString();

                //  建立连接然后发送数据
                String[] hostAndPort = ProxyConstants.fishingHost.split(":");
                targetSocket = new Socket(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
                OutputStream outputStream = targetSocket.getOutputStream();
                outputStream.write(fishingHTTPStr.getBytes());
                waitTargetServerAndTransfer(clientOutputStream, targetSocket.getInputStream());
                return;
            }

            //  对于非GET请求的方法,直接转发给目的服务器
            if (url == null) {
    
    
                transfer(host, port, buffer, clientInputStream, clientOutputStream);
                return;
            }

            String uri = url;
            byte[] content = cache.getContent(uri);
            //  对于GET方法,如果缓存中存在则向目的服务器发送条件GET
            if (content != null) {
    
    
                //  从缓存中提取Last-Modified值
                String lastModified = parseLastModified(content);
                //  构造条件GET请求
                StringBuffer ifGetReqBuffer = new StringBuffer();
                ifGetReqBuffer.append("GET " + url + " HTTP/1.1\r\n");
                ifGetReqBuffer.append("Host: " + host + ":" + port + "\r\n");
                ifGetReqBuffer.append("If-modified-since: " + lastModified + "\r\n");
                ifGetReqBuffer.append("\r\n");
                String ifGetReq = ifGetReqBuffer.toString();
                //  向目的服务器发送
                targetSocket = new Socket(host, port);
                OutputStream outputStream = targetSocket.getOutputStream();
                InputStream inputStream = targetSocket.getInputStream();
                outputStream.write(ifGetReq.getBytes());
                outputStream.flush();
                //  阻塞式监听目的服务器的返回值
                String respFirstLine = IOUtil.readHttpLine(inputStream);
                int code = Integer.parseInt(respFirstLine.split(" ")[1]);
                //  缓存过期
                if (code != 304) {
    
    
                    System.out.println("代理服务器对" + uri + "的缓存过期");
                    //  将报文转发至客户端
                    byte[] firstLine = (respFirstLine + "\r\n").getBytes();
                    clientOutputStream.write(firstLine);
                    byte[] resp = waitTargetServerAndTransfer(clientOutputStream, inputStream);
                    //  将响应结果进行缓存,只缓存具有Last-Modified首部行的响应结果
                    if (new String(resp).contains("Last-Modified")) {
    
    
                        cache.addCache(uri, ArrayUtils.addAll(firstLine, resp));
                        System.out.println("对" + uri + "的响应结果进行了缓存");
                    }
                } else {
    
      // 缓存命中
                    System.out.println("对" + uri + "的访问命中缓存");
                    //  直接返回缓存中的值
                    clientOutputStream.write(content);
                    clientOutputStream.flush();
                }
            } else {
    
     //  缓存不存在,则直接请求目的服务器,然后转发给客户端,并在代理服务器进行缓存
                System.out.println("代理服务器没有对" + uri + "请求的缓存");
                byte[] resp = transfer(host, port, buffer, null, clientOutputStream);
                //  将响应结果进行缓存,只缓存具有Last-Modified首部行的响应结果
                if (new String(resp).contains("Last-Modified")) {
    
    
                    cache.addCache(uri, resp);
                    System.out.println("对" + uri + "的响应结果进行了缓存");
                }
            }
        } catch (Exception e) {
    
    
            e.printStackTrace();
        } finally {
    
    
            try {
    
    
                if (socket != null) {
    
    
                    socket.close();
                }
                if (targetSocket != null) {
    
    
                    targetSocket.close();
                }
            } catch (IOException e) {
    
    
                e.printStackTrace();
            }
        }
    }

    /**
     * 在客户端和目标服务器之间进行转发,也就是将客户端的内容直接发送到目的服务器,然后再将目的服务器返回的内容直接转发给客户端
     * @param host 目标服务器主机地址
     * @param port 目标服务器端口号
     * @param head HTTP的请求头
     * @param body HTTP除去head后的内容的输入流
     * @param clientOutputStream 客户端socket的输出流
     * @return 目标服务器返回的相应内容
     */
    private byte[] transfer(String host, int port, StringBuffer head, InputStream body, OutputStream clientOutputStream) throws IOException {
    
    
        //  和远程服务器建立连接
        //  TODO:有BUG,可能连不上目标服务器
        targetSocket = new Socket(host, port);
        InputStream targetServerInputStream = targetSocket.getInputStream();
        OutputStream targetServerOutputStream = targetSocket.getOutputStream();

        //  先写入请求头
        targetServerOutputStream.write(head.toString().getBytes());
        //  请求体不为null时写入请求体
        if (body != null) {
    
    
            byte[] bytes = new byte[256 * 1024];
            int size;
            // TODO:有BUG,可能读不到完整数据;但是如果while循环读的话,如果目标服务器不关闭TCP连接,则会阻塞在这里
            if ((size = body.read(bytes)) >= 0) {
    
    
                targetServerOutputStream.write(bytes, 0, size);
            }
        }
        targetServerOutputStream.flush();

        //  同步阻塞式等待目标服务器返回响应
        return waitTargetServerAndTransfer(clientOutputStream, targetServerInputStream);
    }

    /**
     * 同步阻塞式等待目标服务器返回响应,并且将响应结果直接返回给客户端
     * @param clientOutputStream 到客户端的输出流
     * @param targetServerInputStream 到目的服务器的输入流
     * @return 客户端的响应结果
     */
    private byte[] waitTargetServerAndTransfer(OutputStream clientOutputStream,
                                               InputStream targetServerInputStream) throws IOException {
    
    
        List<byte[]> response = new ArrayList<>();
        byte[] bytes = new byte[256 * 1024];
        int length;
        // TODO:有BUG,可能读不到完整数据;但是如果while循环读的话,如果目标服务器不关闭TCP连接,则会阻塞在这里
        if ((length = targetServerInputStream.read(bytes)) >= 0) {
    
    
            //  写回给客户端
            clientOutputStream.write(bytes, 0, length);
            //  收集响应结果
            byte[] part = new byte[length];
            System.arraycopy(bytes, 0, part, 0, length);
            response.add(part);
        }

        //  将响应结果返回
        List<Byte> list = new LinkedList<>();
        for (byte[] one: response) {
    
    
            list.addAll(Bytes.asList(one));
        }
        return Bytes.toArray(list);
    }

    /**
     * 从缓存中提取Last-Modified值
     * @param context HTTP报文
     * @return Last-Modified值,如果没有则返回null
     */
    private String parseLastModified(byte[] context) {
    
    
        StringBuffer headLine = new StringBuffer();
        for (int i = 0; i < context.length; i++) {
    
    
            if (context[i] == '\r') {
    
    
                //  请求头解析结束时都没有找到Last-Modified请求行
                if (headLine.length() == 0) {
    
    
                    return null;
                }
                String s = headLine.toString();
                if (s.startsWith("Last-Modified")) {
    
    
                    return s.substring(15);
                }
                i++;
                headLine = new StringBuffer();
                continue;
            }
            headLine.append((char) context[i]);
        }
        return null;
    }

    /**
     * 拒绝代理时向客户端返回的HTTP报文
     */
    private String refuseProxy() {
    
    
        String resp = "HTTP/1.1 500 Internal Server Error\r\n";
        resp += "\r\n";
        return resp;
    }
}

这里只是列出了重要代码,如需查看完整代码,或者想要参考我的编程风格,请移步到该实验的代码仓库去查看,相信你一定会有收获的。(代码仓库:https://github.com/Zhang-Qing-Yun/network-lab

猜你喜欢

转载自blog.csdn.net/zhang_qing_yun/article/details/121050618