Java多线程(7)手写线程池

上一篇:Java多线程(6)CAS详解

手写线程池

一. 线程阻塞获取任务

1. 编写任务队列 以生产者和消费者模型来实现

class TaskQueue<T> {
    // 1. 任务队列实例
    private Deque<T> queue = new ArrayDeque<>();
    // 2. 任务队列大小
    private final static int queueSize = 10;
    // 3. 锁
    private ReentrantLock lock = new ReentrantLock();
    // 4. 生产者等待的条件变量
    private Condition producerWaitSet = lock.newCondition();
    // 5. 消费者等待的条件变量
    private Condition consumerWaitSet = lock.newCondition();
    // 生产者添加任务(阻塞)
    public void addTask(T task){
        // 加锁
        lock.lock();
        try{
            // 如果 任务队列满了
            while(queue.size() == queueSize){
                try {
                    // 生产者(指的是添加任务的主线程) 等待
                    producerWaitSet.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 把任务添加到集合最后面的一个位置
            queue.addLast(task);
            // 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
            consumerWaitSet.signal();
        } finally {
            //释放锁
            lock.unlock();
        }
    }
    // 消费者获取任务(阻塞)
    public T getTask(){
        lock.lock();
        try{
            // 如果任务(即生产者没有添加任务)为空
            while(queue.isEmpty()){
                try {
                    // 则让 消费者(这里的消费者指的是线程池的线程) 等待
                    consumerWaitSet.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 拿出第一个任务
            T t = queue.removeFirst();
            // 唤醒生产者(即添加任务的主线程)线程,继续添加线程
            producerWaitSet.signal();
            return t;
        } finally {
            //释放锁
            lock.unlock();
        }
    }
}

这里难在对生产者消费者模型的理解上,以线程池为例,

  • 消费者就是线程池里面的每一个线程
  • 生产者就是添加任务的主线程

2. 编写线程池

class MyThreadPool {
    // 1. 寄存的任务队列
    private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
    // 2. 线程集合 这里用封装的内部类Worker实现
    private HashSet<Worker> workers = new HashSet<>();
    // 3. 线程数
    private int coreSize;
    public MyThreadPool(int coreSize){
        this.coreSize = coreSize;
    }

    // 执行任务 即主线程(生产者) 添加任务
    public void executeTask(Runnable task){
        // 判断还有没有 空闲的线程
        if (workers.size() < coreSize){
            // 如果有,就新建线程
            Worker worker = new Worker(task);
            // 添加进去,表示少了一个线程
            workers.add(worker);
            worker.start();
        } else {
            // 没有空闲的线程的话 就添加到任务队列  等待线程池的每个线程执行完当前的进程
            taskQueue.addTask(task);
        }
    }

    class Worker extends Thread {
        private Runnable task;
        public Worker(Runnable task) {
            this.task = task;
        }
        @Override
        public void run() {
            // 如果当前任务执行完毕,并且获取不到新的任务下就退出while
            while (task != null || (task = taskQueue.getTask())!=null){
                try {
                    task.run();
                } catch (Exception e){
                    e.printStackTrace();
                } finally {
                    task = null;
                }
            }
            synchronized (workers){
                // 执行完任务就移除,表示多了一个线程
                workers.remove(this);
            }
        }
    }

}

这里的线程池给了一个有参的构造方法,参数是线程池的最大容量(即最大可同时执行的线程大小),然后给了一个开始任务的方法executeTask,后面的我不多说了,懂得都懂,不懂得多敲几编,代码注释都加上了,顺带我个人的理解

3. 测试代码

3.1 情况1 线程池容量10 给安排8个任务

public class TestThreadPool {
    // 测试代码
    public static void main(String[] args) {
        MyThreadPool threadPool = new MyThreadPool(10);
        //让10个线程去 执行 8 个任务
        for (int i = 0; i < 8; i++) {
            int j = i+1;
            threadPool.executeTask(new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
                }
            });
        }
    }
}

在这里插入图片描述

3.2 情况2 线程池容量10 给安排30个任务

public class TestThreadPool {
    // 测试代码
    public static void main(String[] args) {
        MyThreadPool threadPool = new MyThreadPool(10);
        //让10个线程去 执行 30 个任务
        for (int i = 0; i < 30; i++) {
            int j = i+1;
            threadPool.executeTask(new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
                }
            });
        }
    }

}

运行结果:
在这里插入图片描述

3.3 总结

可以发现任务都执行完了,诡异的是为什么主线程还是没有停止下来

这个时候可以看线程池内部类worker的 run方法 里面的whlie循环

在这里插入图片描述

因为这个方法让每个都进入了wait,所以我们在下面改造一下

二. 线程非阻塞获取任务

1. 生产者消费者模型

在上面的基础上新增一个增强版获取任务的方法

// 消费者获取任务增强版(非阻塞) (传递两个参数  1是时间 2是时间单位)
    public T getTaskEnhance(long timeout, TimeUnit timeUnit){
        lock.lock();
        try{
            // 统一时间格式
            long nanos = timeUnit.toNanos(timeout);
            while(queue.isEmpty()){
                try {
                    // 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
                    if (nanos <= 0){
                        return null;
                    }
                    // awaitNanos方法 对比 await(时间,时间单位) 方法的区别就是 如果不到等待时间被打断 他会返回剩余时间
                    nanos = consumerWaitSet.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            T t = queue.removeFirst();
            producerWaitSet.signal();
            return t;
        } finally {
            lock.unlock();
        }
    }

2. 线程池实现

因为之前的getTask方法是阻塞方法,所以我们可以调用这个新加的增强版方法,通过构造方法的方式传递时间,和时间单位

class MyThreadPool {
    // 1. 寄存的任务队列
    private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
    // 2. 线程集合 这里用封装的内部类Worker实现
    private HashSet<Worker> workers = new HashSet<>();
    // 3. 线程数
    private int coreSize;
    // 4. 获取任务的超时时间
    private long timeout;
    // 5. 时间单位
    private TimeUnit timeUnit;
    public MyThreadPool(int coreSize, long timeout, TimeUnit timeUnit){
        this.coreSize = coreSize;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
    }

    // 执行任务 即主线程(生产者) 添加任务
    public void executeTask(Runnable task){
        // 判断还有没有 空闲的线程
        if (workers.size() < coreSize){
            // 如果有,就新建线程
            Worker worker = new Worker(task);
            // 添加进去,表示少了一个线程
            workers.add(worker);
            worker.start();
        } else {
            // 没有空闲的线程的话 就添加到任务队列  等待线程池的每个线程执行完当前的进程
            taskQueue.addTask(task);
        }
    }

    class Worker extends Thread {
        private Runnable task;
        public Worker(Runnable task) {
            this.task = task;
        }
        @Override
        public void run() {
            // 如果当前任务执行完毕,并且获取不到新的任务下就退出while
            while (task != null || (task = taskQueue.getTaskEnhance(timeout,timeUnit))!=null){
                try {
                    task.run();
                } catch (Exception e){
                    e.printStackTrace();
                } finally {
                    task = null;
                }
            }
            synchronized (workers){
                // 执行完任务就移除,表示多了一个线程
                workers.remove(this);
            }
        }
    }
}

3. 测试代码

3.1 运行结果

这里我们只测试一种情况就是线程容量为10,任务个数为30

public class TestThreadPool {
    // 测试代码
    public static void main(String[] args) {
        MyThreadPool threadPool = new MyThreadPool(10,100,TimeUnit.MILLISECONDS);
        //让10个线程去 执行 30 个任务
        for (int i = 0; i < 30; i++) {
            int j = i+1;
            threadPool.executeTask(new Runnable() {
                @Override
                public void run() {
                    System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
                }
            });
        }
    }
}

运行结果:
在这里插入图片描述

可以看到我们的主线程已经关闭了,任务也按预想的执行完成了

3.2 总结

还有一种情况没有测试 ,就是executeTask方法中

在这里插入图片描述

这里,添加的任务超过了任务队列最大容量,那么主线程就会一种死等,等待任务添加完成,这里的解决方案和上面的添加方法增强版是一样的

// 生产者添加任务增强版(非阻塞)(传递两个参数  1是时间 2是时间单位)
    public boolean addTaskEnhance(T task,long timeout, TimeUnit timeUnit){
        // 加锁
        lock.lock();
        try{
            // 统一时间格式
            long nanos = timeUnit.toNanos(timeout);
            // 如果 任务队列满了
            while(queue.size() == queueSize){
                try {
                    // 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
                    if (nanos <= 0){
                        return false;
                    }
                    // 生产者(指的是添加任务的主线程) 等待
                    nanos = producerWaitSet.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 把任务添加到集合最后面的一个位置
            queue.addLast(task);
            // 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
            consumerWaitSet.signal();
        } finally {
            //释放锁
            lock.unlock();
        }
        return true;
    }

还有就是通过策略模式让用户自己选择

三. 策略模式(队列已满的情况下让用户选择死等还是其他操作)

1. 新建策略模式的接口类

@FunctionalInterface
interface RejectPolicy<T> {
    void reject(TaskQueue<T> queue, T task);
}

2. 生产者消费者模型

class TaskQueue<T> {
    // 1. 任务队列实例
    private Deque<T> queue = new ArrayDeque<>();
    // 2. 任务队列大小
    private final static int queueSize = 10;
    // 3. 锁
    private ReentrantLock lock = new ReentrantLock();
    // 4. 生产者等待的条件变量
    private Condition producerWaitSet = lock.newCondition();
    // 5. 消费者等待的条件变量
    private Condition consumerWaitSet = lock.newCondition();
    // 生产者添加任务(阻塞)
    public void addTask(T task){
        // 加锁
        lock.lock();
        try{
            // 如果 任务队列满了
            while(queue.size() == queueSize){
                try {
                    // 生产者(指的是添加任务的主线程) 等待
                    producerWaitSet.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 把任务添加到集合最后面的一个位置
            queue.addLast(task);
            // 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
            consumerWaitSet.signal();
        } finally {
            //释放锁
            lock.unlock();
        }
    }
    // 生产者添加任务增强版(非阻塞)(传递两个参数  1是时间 2是时间单位)
    public boolean addTaskEnhance(T task,long timeout, TimeUnit timeUnit){
        // 加锁
        lock.lock();
        try{
            // 统一时间格式
            long nanos = timeUnit.toNanos(timeout);
            // 如果 任务队列满了
            while(queue.size() == queueSize){
                try {
                    // 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
                    if (nanos <= 0){
                        return false;
                    }
                    // 生产者(指的是添加任务的主线程) 等待
                    nanos = producerWaitSet.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 把任务添加到集合最后面的一个位置
            queue.addLast(task);
            // 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
            consumerWaitSet.signal();
        } finally {
            //释放锁
            lock.unlock();
        }
        return true;
    }
    // 消费者获取任务(阻塞)
    public T getTask(){
        lock.lock();
        try{
            // 如果任务(即生产者没有添加任务)为空
            while(queue.isEmpty()){
                try {
                    // 则让 消费者(这里的消费者指的是线程池的线程) 等待
                    consumerWaitSet.await();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            // 拿出第一个任务
            T t = queue.removeFirst();
            // 唤醒生产者(即添加任务的主线程)线程,继续添加线程
            producerWaitSet.signal();
            return t;
        } finally {
            //释放锁
            lock.unlock();
        }
    }
    // 消费者获取任务增强版(非阻塞) (传递两个参数  1是时间 2是时间单位)
    public T getTaskEnhance(long timeout, TimeUnit timeUnit){
        lock.lock();
        try{
            // 统一时间格式
            long nanos = timeUnit.toNanos(timeout);
            while(queue.isEmpty()){
                try {
                    // 判断有没有超过获取的超时时间 这个判断会在虚假唤醒的情况下执行
                    if (nanos <= 0){
                        return null;
                    }
                    // awaitNanos方法 对比 await(时间,时间单位) 的区别就是 如果不到等待时间被打断 他会返回剩余时间
                    nanos = consumerWaitSet.awaitNanos(nanos);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
            T t = queue.removeFirst();
            producerWaitSet.signal();
            return t;
        } finally {
            lock.unlock();
        }
    }

    // 生产者添加,如果任务队列满了的话,策略模式选让用户选择如何处理
    public void tryAddTask(RejectPolicy<T> reject, T task){
        // 加锁
        lock.lock();
        try{
            // 如果 任务队列满了
            if (queue.size() == queueSize){
                reject .reject(this,task);
            }
            // 把任务添加到集合最后面的一个位置
            queue.addLast(task);
            // 添加成功就唤醒 消费者(这里的消费者指的是线程池的线程)
            consumerWaitSet.signal();
        } finally {
            //释放锁
            lock.unlock();
        }
    }
}

3. 编写线程池

class MyThreadPool<T> {
    // 1. 寄存的任务队列
    private TaskQueue<Runnable> taskQueue = new TaskQueue<>();
    // 2. 线程集合 这里用封装的内部类Worker实现
    private HashSet<Worker> workers = new HashSet<>();
    // 3. 线程数
    private int coreSize;
    // 4. 获取任务的超时时间
    private long timeout;
    // 5. 时间单位
    private TimeUnit timeUnit;
    private RejectPolicy<Runnable> reject;
    public MyThreadPool(int coreSize, long timeout, TimeUnit timeUnit,RejectPolicy<Runnable> reject){
        this.coreSize = coreSize;
        this.timeout = timeout;
        this.timeUnit = timeUnit;
        this.reject = reject;
    }

    // 执行任务 即主线程(生产者) 添加任务
    public void executeTask(Runnable task){
        // 判断还有没有 空闲的线程
        if (workers.size() < coreSize){
            // 如果有,就新建线程
            Worker worker = new Worker(task);
            // 添加进去,表示少了一个线程
            workers.add(worker);
            worker.start();
        } else {
            // 没有空闲的线程的话 就添加到任务队列  等待线程池的每个线程执行完当前的进程
            taskQueue.tryAddTask(reject,task);
        }
    }

    class Worker extends Thread {
        private Runnable task;
        public Worker(Runnable task) {
            this.task = task;
        }
        @Override
        public void run() {
            // 如果当前任务执行完毕,并且获取不到新的任务下就退出while
            while (task != null || (task = taskQueue.getTaskEnhance(timeout,timeUnit))!=null){
                try {
                    task.run();
                } catch (Exception e){
                    e.printStackTrace();
                } finally {
                    task = null;
                }
            }
            synchronized (workers){
                // 执行完任务就移除,表示多了一个线程
                workers.remove(this);
            }
        }
    }
}

重点在这里:
在这里插入图片描述

4. 测试代码

public class TestThreadPool {
    // 测试代码
    public static void main(String[] args) {
        MyThreadPool threadPool = new MyThreadPool(10, 100, TimeUnit.MILLISECONDS, new RejectPolicy<Runnable>() {
            @Override
            public void reject(TaskQueue<Runnable> queue, Runnable task) {
                //queue.addTask(task); // 如果任务队列满了,就死等
                queue.addTaskEnhance(task,100,TimeUnit.MILLISECONDS); // 如果任务队列满了,就等100 毫秒,199毫秒添加不上就算了
               //  throw new RuntimeException(); 抛出异常

            }
        });
        //让10个线程去 执行 30 个任务
        for (int i = 0; i < 40; i++) {
            int j = i+1;
            threadPool.executeTask(new Runnable() {
                @Override
                public void run() {
                    try {
                        Thread.sleep(3000);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                    System.out.println(Thread.currentThread().getName() + "执行了第" + j + "个任务");
                }
            });
        }
    }
}

三种情况,生产者(主线程)想用哪种就哪种

猜你喜欢

转载自blog.csdn.net/haiyanghan/article/details/109559493