并发工具类:Fork、Join和CompletionService

一、Fork/Join

Fork就是把一个大任务切分为若干子任务并行的执行,Join就是合并这些子任务的执行结果,最后得到这个大任务的结果

工作窃取算法

假如我们需要做一个比较大的任务,可以把这个任务分割为若干互不依赖的子任务,为了减少线程间的竞争,把这些子任务分别放到不同的队列里,并为每个队列创建一个单独的线程来执行队列里的任务,线程和队列一一对应。比如A线程负责处理A队列里的任务。但是,有的线程会先把自己队列里的任务干完,而其他线程对应的队列里还有任务等待处理。干完活的线程与其等着,不如去帮其他线程干活,于是它就去其他线程的队列里窃取一个任务来执行。而在这时它们会访问同一个队列,所以为了减少窃取任务线程和被窃取任务线程之间的竞争,通常会使用双端队列,被窃取任务线程永远从双端队列的头部拿任务执行,而窃取任务的线程永远从双端队列的尾部拿任务执行

优点:充分利用线程进行并行计算,减少了线程间的竞争

缺点:在某些情况下还是存在竞争,比如双端队列里只有一个任务时,并且该算法会消耗了更多的系统资源,比如创建多个线程和多个双端队列

Fork/Join使用

ForkJoinTask 是一个抽象类,它的方法有很多,最核心的是fork()方法和join()方法,其中fork()方法会异步地执行一个子任务,而join()方法则会阻塞当前线程来等待子任务的执行结果。通常不需要直接继承ForkJoinTask类,只需要继承它的子类:RecursiveAction(用于没有返回结果的任务)、RecursiveTask(用于有返回结果的任务),ForkJoinTask需要通过ForkJoinPool来执行

任务分割出的子任务会添加到当前工作线程所维护的双端队列中,进入队列的头部。当一个工作线程的队列里暂时没有任务时,它会随机从其他工作线程的队列的尾部获取一个任务

案例1:计算斐波那契数列

public class Fibonacci extends RecursiveTask<Integer> {
    private int n;

    public Fibonacci(int n) {
        this.n = n;
    }

    @Override
    protected Integer compute() {
        if (n <= 1) {
            return n;
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        //创建子任务
        f1.fork();
        Fibonacci f2 = new Fibonacci(n - 2);
        //等待子任务,并合并结果
        return f2.compute() + f1.join();
    }
}
        //创建分治任务线程池
        ForkJoinPool fjp = new ForkJoinPool(4);
        //创建分治任务
        Fibonacci fib = new Fibonacci(30);
        //启动分治任务并输出结果
        System.out.println(fjp.invoke(fib));

案例2:模拟MapReduce统计单词数量

public class MR extends RecursiveTask<Map<String, Long>> {
    private String[] fc;
    private int start, end;

    public MR(String[] fc, int start, int end) {
        this.fc = fc;
        this.start = start;
        this.end = end;
    }

    @Override
    protected Map<String, Long> compute() {
        if (end - start == 1) {
            return calc(fc[start]);
        }
        int mid = (end + start) / 2;
        MR mr1 = new MR(fc, start, mid);
        mr1.fork();
        MR mr2 = new MR(fc, mid, end);
        //计算子任务,并返回合并的结果
        return merge(mr2.compute(), mr1.join());
    }

    //合并结果
    private Map<String, Long> merge(Map<String, Long> r1, Map<String, Long> r2) {
        Map<String, Long> result = new HashMap<>();
        result.putAll(r1);
        //合并结果
        r2.forEach((k, v) -> {
            result.merge(k, v, Long::sum);
        });
        return result;
    }

    private Map<String, Long> calc(String line) {
        Map<String, Long> result = new HashMap<>();
        //分割单词
        String[] words = line.split("\\s+");
        for (String word : words) {
            Long value = result.get(word);
            if (value != null) {
                result.put(word, value + 1);
            } else {
                result.put(word, 1L);
            }
        }
        return result;
    }
}
        String[] fc = {"hello world", "hello me", "hello fork", "hello join", "fork join in world"};
        //创建ForkJoin线程池
        ForkJoinPool fjp = new ForkJoinPool(3);
        //创建任务
        MR mr = new MR(fc, 0, fc.length);
        //启动任务
        Map<String, Long> result = fjp.invoke(mr);
        //输出结果
        result.forEach((k, v) -> System.out.println(k + ":" + v));

ForkJoinTask的fork()方法实现原理

当调用ForkJoinTask的fork()方法时,程序会调用ForkJoinWorkerThread的push()方法异步地执行这个任务

    public final ForkJoinTask<V> fork() {
        Thread t;
        if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
            ((ForkJoinWorkerThread)t).workQueue.push(this);
        else
            ForkJoinPool.common.externalPush(this);
        return this;
    }

push()方法把当前任务存放在ForkJoinTask数组队列里,然后再调用ForkJoinPool的signalWork()方法唤醒或创建一个工作线程来执行任务

        final void push(ForkJoinTask<?> task) {
            ForkJoinTask<?>[] a; ForkJoinPool p;
            int b = base, s = top, n;
            if ((a = array) != null) {    // ignore if queue removed
                int m = a.length - 1;     // fenced write for task visibility
                U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
                U.putOrderedInt(this, QTOP, s + 1);
                if ((n = s - b) <= 1) {
                    if ((p = pool) != null)
                        p.signalWork(p.workQueues, this);
                }
                else if (n >= m)
                    growArray();
            }
        }

ForkJoinTask的join()方法实现原理

    public final V join() {
        int s;
        if ((s = doJoin() & DONE_MASK) != NORMAL)
            reportException(s);
        return getRawResult();
    }

当调用ForkJoinTask的join()方法时,程序会调用doJoin()方法,通过doJoin()方法得到当前任务的状态来判断返回什么结果,任务状态有4种:已完成(NORMAL)、被取消(CANCELLED)、信号(SIGNAL)和出现异常(EXCEPTIONAL)

  • 如果任务状态是已完成,则直接返回任务结果
  • 如果任务状态是被取消,则直接抛出CancellationException
  • 如果任务状态是抛出异常,则直接抛出对应的异常
    private int doJoin() {
        int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
        return (s = status) < 0 ? s :
            ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
            (w = (wt = (ForkJoinWorkerThread)t).workQueue).
            tryUnpush(this) && (s = doExec()) < 0 ? s :
            wt.pool.awaitJoin(w, this, 0L) :
            externalAwaitDone();
    }

在doJoin()方法里,首先通过查看任务的状态,看任务是否已经执行完成,如果执行完成,则直接返回任务状态;如果没有执行完,则从任务数组里取出任务并执行。如果任务顺利执行完成,则设置任务状态为NORMAL,如果出现异常,则记录异常,并将任务状态设置为EXCEPTIONAL

二、CompletionService和ExecutorCompletionService

CompletionService接口中定义了一组任务管理的方法:

public interface CompletionService<V> {
    //提交一个Callable类型任务,并返回该任务执行结果关联的Future
    Future<V> submit(Callable<V> task);

  	//提交一个Runnable类型任务,并返回该任务执行结果关联的Future
    Future<V> submit(Runnable task, V result);

  	//如果完成队列中有已完成的任务,take方法就返回任务的结果,否则阻塞等待任务完成
    Future<V> take() throws InterruptedException;

  	//如果完成队列中有数据就返回,否则返回null
    Future<V> poll();

  	//如果完成队列中有数据就直接返回,否则等待指定的时间,到时间后如果还是没有数据就返回null
    Future<V> poll(long timeout, TimeUnit unit) throws InterruptedException;
}

ExecutorCompletionService类是CompletionService接口的实现,它的内部内部维护列一个用于管理已完成的任务的队列,引用了一个Executor用来执行任务,submit()方法最终会委托给内部的Executor去执行任务,take()/poll()方法的工作都委托给内部的已完成任务阻塞队列

    public void demo() throws InterruptedException, ExecutionException {
        //要执行的任务
        List<Task> taskList = new ArrayList<>();
        for (int i = 0; i < 10; ++i) {
            taskList.add(new Task());
        }

        ExecutorService executorService = Executors.newFixedThreadPool(5);
        ExecutorCompletionService<String> completionService = new ExecutorCompletionService<>(executorService);

        int n = taskList.size();
        //提交任务
        for (Task task : taskList) {
            completionService.submit(task);
        }
        //处理任务返回数据
        for (int i = 0; i < n; ++i) {
            Future<String> resultHolder = completionService.take();
            System.out.println("result:" + resultHolder.get());
        }
        System.out.println("task done!");
        executorService.shutdown();
    }

    class Task implements Callable<String> {

        @Override
        public String call() throws Exception {
            TimeUnit.SECONDS.sleep(new Random().nextInt(5));
            return Thread.currentThread().getName();
        }
    }

参考:

《Java并发编程的艺术》

《Java并发编程实战》

https://cloud.tencent.com/developer/article/1444259

https://blog.csdn.net/m0_38031406/article/details/87778215

猜你喜欢

转载自blog.csdn.net/qq_40378034/article/details/107552683