[Concurrent programming] Java7 - ForkJoin, splitting large tasks into small tasks

1 Introduction

  Java 7 introduced the Fork/Join framework, which splits a large task into smaller tasks, executes them, and then merges their results. Splitting a large task into enough small tasks to run concurrently is called Fork; combining the results of those small tasks into the final result is called Join.
  The core abstraction of the Fork/Join framework is the abstract class ForkJoinTask, which implements Future and runs in a ForkJoinPool. It has three standard abstract subclasses: RecursiveAction, RecursiveTask, and CountedCompleter (added in Java 8).
  RecursiveAction produces no result, while RecursiveTask returns a result. CountedCompleter is similar to RecursiveAction, but it is more robust when subtasks block or take a long time. It can return a result by overriding getRawResult(), or return no result (null); see the detailed explanation below.
  Note that a task should not be split as finely as possible. Instead, split it, according to business needs, into just enough subtasks to run at once. For example, to query 50,000 records at once, you can split the work into 10 subtasks of 5,000 records each and run the 10 subtasks concurrently.

2. RecursiveAction

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Unit tests demonstrating {@link RecursiveAction}.
 *
 * <p>Every {@link ForkJoinPool} created here is shut down once its task has finished, so test runs
 * do not leave idle worker threads behind.
 *
 * @author CL
 */
@Slf4j
public class RecursiveActionTest {

    /**
     * Computes the 10th Fibonacci number with a {@link RecursiveAction}.
     */
    @Test
    public void testFibonacci() {
        int n = 10;

        Fibonacci task = new Fibonacci(n);
        invokeInNewPool(task);
        int result = task.getResult();

        log.info("第 {} 个斐波那契数为:{}", n, result);
    }

    /**
     * Runs {@code task} in a freshly created {@link ForkJoinPool}, waits for it to finish, and then
     * shuts the pool down.
     *
     * <p>The pool is built with the no-arg constructor, whose parallelism defaults to the number of
     * available processors.
     *
     * @param task the task to execute
     */
    private static void invokeInNewPool(ForkJoinTask<?> task) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            pool.invoke(task);
        } finally {
            // shutdown() does not block: it stops the pool from accepting new work so its workers
            // can terminate instead of lingering after the test.
            pool.shutdown();
        }
    }

    /**
     * Returns the sum of {@code numbers} using int arithmetic (overflow wraps, like the plain loop).
     */
    private static int sumOf(List<Integer> numbers) {
        return numbers.stream().mapToInt(n -> n).sum();
    }

    /**
     * Computes the {@code n}-th Fibonacci number.
     *
     * <p>A {@link RecursiveAction} has no return value, so the result is stored in {@link #result}.
     */
    @RequiredArgsConstructor
    private static class Fibonacci extends RecursiveAction {

        private final int n;
        @Getter
        private int result;

        @Override
        protected void compute() {
            if (n <= 1) {
                result = n;
                return;
            }
            Fibonacci f1 = new Fibonacci(n - 1);
            Fibonacci f2 = new Fibonacci(n - 2);
            // invokeAll returns only after both subtasks are done, so their results are safe to read.
            ForkJoinTask.invokeAll(f1, f2);

            result = f1.getResult() + f2.getResult();
        }

    }

    /**
     * Sums 50,000 random numbers four ways (plain loop, stream, blocking fork/join, non-blocking
     * fork/join), checks that all results agree, and logs the time taken by each approach.
     */
    @Test
    public void testSum() {
        // Build the test data
        int total = 50_000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        Sum task = new Sum(list);
        invokeInNewPool(task);
        int result3 = task.getResult();

        stopWatch.stop();

        stopWatch.start("ForkJoin非阻塞方式计算");

        Sum2 task2 = new Sum2(list, null);
        invokeInNewPool(task2);
        int result4 = task2.getResult();

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * Sums a list by halving it until each piece has at most {@link #THRESHOLD} elements.
     *
     * <p>The total is stored in {@link #result}.
     */
    @RequiredArgsConstructor
    private static class Sum extends RecursiveAction {

        private final static int THRESHOLD = 1000;
        private final List<Integer> list;
        @Getter
        private int result;

        @Override
        protected void compute() {
            int total = list.size();
            if (total <= THRESHOLD) {
                result = sumOf(list);
                return;
            }
            int middle = total / 2;
            Sum s1 = new Sum(list.subList(0, middle));
            Sum s2 = new Sum(list.subList(middle, total));
            // Blocks (while helping) until both halves are done.
            ForkJoinTask.invokeAll(s1, s2);

            result = s1.getResult() + s2.getResult();
        }

    }

    /**
     * Sums a list without blocking on each split.
     *
     * <p>The list is cut from the front into {@link #THRESHOLD}-sized chunks. Each chunk is forked
     * and linked to the previously forked chunk through {@link #subTask}. The remaining tail chunk is
     * summed by the current thread. The chain is then walked newest-first: each chunk is either
     * reclaimed with {@code tryUnfork()} and summed inline, or joined if another worker already took
     * it.
     */
    @RequiredArgsConstructor
    private static class Sum2 extends RecursiveAction {

        private final static int THRESHOLD = 1000;
        /** The elements this task sums. */
        @Getter
        private final List<Integer> list;
        /** The previously forked chunk, or {@code null} at the end of the chain. */
        @Getter
        private final Sum2 subTask;
        @Getter
        private int result;

        @Override
        protected void compute() {
            int total = list.size();
            int start = 0;
            Sum2 tempTask = null;
            // Fork THRESHOLD-sized chunks from the front, chaining each to the previous one
            while (total > THRESHOLD) {
                tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
                tempTask.fork();

                start = start + THRESHOLD;
                total = total - THRESHOLD;
            }
            // Sum the remaining last chunk in this thread
            int sum = sumOf(list.subList(start, list.size()));
            // Collect the forked chunks' results, most recently forked first
            while (tempTask != null) {
                if (tempTask.tryUnfork()) {
                    // Not started by anyone yet: take it back and sum it here
                    sum += sumOf(tempTask.getList());
                } else {
                    tempTask.join();
                    sum += tempTask.getResult();
                }
                tempTask = tempTask.getSubTask();
            }

            result = sum;
        }

    }

}

  Fibonacci sequence test results:

20:16:22.523 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 第 10 个斐波那契数为:55

  Test results for summing a list of numbers:

20:28:56.607 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - 普通计算 耗时:17 ms
20:28:56.614 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - Lambda计算 耗时:10 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin计算 耗时:13 ms
20:28:56.615 [main] INFO com.c3stones.forkjoin.my.RecursiveActionTest - ForkJoin非阻塞方式计算 耗时:10 ms

3. RecursiveTask

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Unit tests demonstrating {@link RecursiveTask}.
 *
 * <p>Every {@link ForkJoinPool} created here is shut down once its task has finished, so test runs
 * do not leave idle worker threads behind.
 *
 * @author CL
 */
@Slf4j
public class RecursiveTaskTest {

    /**
     * Computes the 10th Fibonacci number with a {@link RecursiveTask}.
     */
    @Test
    public void testFibonacci() {
        int n = 10;

        Fibonacci task = new Fibonacci(n);
        int result = invokeInNewPool(task);

        log.info("第 {} 个斐波那契数为:{}", n, result);
    }

    /**
     * Runs {@code task} in a freshly created {@link ForkJoinPool}, shuts the pool down, and returns
     * the task's result.
     *
     * <p>The pool is built with the no-arg constructor, whose parallelism defaults to the number of
     * available processors.
     *
     * @param task the task to execute
     * @param <T>  the task's result type
     * @return the value computed by {@code task}
     */
    private static <T> T invokeInNewPool(RecursiveTask<T> task) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(task);
        } finally {
            // shutdown() does not block: it stops the pool from accepting new work so its workers
            // can terminate instead of lingering after the test.
            pool.shutdown();
        }
    }

    /**
     * Returns the sum of {@code numbers} using int arithmetic (overflow wraps, like the plain loop).
     */
    private static int sumOf(List<Integer> numbers) {
        return numbers.stream().mapToInt(n -> n).sum();
    }

    /**
     * Computes the {@code n}-th Fibonacci number and returns it as the task result.
     */
    @RequiredArgsConstructor
    private static class Fibonacci extends RecursiveTask<Integer> {

        private final int n;

        @Override
        protected Integer compute() {
            if (n <= 1) {
                return n;
            }
            Fibonacci f1 = new Fibonacci(n - 1);
            f1.fork();
            Fibonacci f2 = new Fibonacci(n - 2);

            // Operands are evaluated left to right: compute f2 in this thread first and only then
            // join f1, so this worker does useful work instead of waiting on f1 straight away.
            return f2.compute() + f1.join();
        }

    }

    /**
     * Sums 50,000 random numbers four ways (plain loop, stream, blocking fork/join, non-blocking
     * fork/join), checks that all results agree, and logs the time taken by each approach.
     */
    @Test
    public void testSum() {
        // Build the test data
        int total = 50_000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        Sum task = new Sum(list);
        int result3 = invokeInNewPool(task);

        stopWatch.stop();

        stopWatch.start("ForkJoin非阻塞方式计算");

        Sum2 task2 = new Sum2(list, null);
        int result4 = invokeInNewPool(task2);

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3 && result1 == result4, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * Sums a list by halving it until each piece has at most {@link #THRESHOLD} elements.
     */
    @RequiredArgsConstructor
    private static class Sum extends RecursiveTask<Integer> {

        private final static int THRESHOLD = 1000;
        private final List<Integer> list;

        @Override
        protected Integer compute() {
            int total = list.size();
            if (total <= THRESHOLD) {
                return sumOf(list);
            }
            int middle = total / 2;
            Sum s1 = new Sum(list.subList(0, middle));
            s1.fork();
            Sum s2 = new Sum(list.subList(middle, total));

            // Work on the second half here before joining the forked first half.
            return s2.compute() + s1.join();
        }

    }

    /**
     * Sums a list without blocking on each split.
     *
     * <p>The list is cut from the front into {@link #THRESHOLD}-sized chunks. Each chunk is forked
     * and linked to the previously forked chunk through {@link #subTask}. The remaining tail chunk is
     * summed by the current thread. The chain is then walked newest-first: each chunk is either
     * reclaimed with {@code tryUnfork()} and summed inline, or joined if another worker already took
     * it.
     */
    @RequiredArgsConstructor
    private static class Sum2 extends RecursiveTask<Integer> {

        private final static int THRESHOLD = 1000;
        /** The elements this task sums. */
        @Getter
        private final List<Integer> list;
        /** The previously forked chunk, or {@code null} at the end of the chain. */
        @Getter
        private final Sum2 subTask;

        @Override
        protected Integer compute() {
            int total = list.size();
            int start = 0;
            Sum2 tempTask = null;
            // Fork THRESHOLD-sized chunks from the front, chaining each to the previous one
            while (total > THRESHOLD) {
                tempTask = new Sum2(list.subList(start, start + THRESHOLD), tempTask);
                tempTask.fork();

                start = start + THRESHOLD;
                total = total - THRESHOLD;
            }
            // Sum the remaining last chunk in this thread
            int sum = sumOf(list.subList(start, list.size()));
            // Collect the forked chunks' results, most recently forked first
            while (tempTask != null) {
                if (tempTask.tryUnfork()) {
                    // Not started by anyone yet: take it back and sum it here
                    sum += sumOf(tempTask.getList());
                } else {
                    sum += tempTask.join();
                }
                tempTask = tempTask.getSubTask();
            }

            return sum;
        }

    }

}

  Fibonacci sequence test results:

21:54:32.523 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - 第 10 个斐波那契数为:55

  Test results for summing a list of numbers:

21:54:33.092 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - 普通计算 耗时:7 ms
21:54:33.096 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - Lambda计算 耗时:6 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin计算 耗时:9 ms
21:54:33.097 [main] INFO com.c3stones.forkjoin.my.RecursiveTaskTest - ForkJoin非阻塞方式计算 耗时:6 ms

4. CountedCompleter

  CountedCompleter是Java8推出的ForkJoinTask实现类,相比于Java7推出的以上两个实现类来说,它在子任务耗时过长或阻塞时更加健壮,但是理解起来比较困难。
  该类有两个成员变量:completer保存当前任务的父任务,如果为null,则表示顶级任务;pending保存当前任务的子任务数,每当子任务执行完成则pending会减1,当pending==0时,表示当前任务执行完成。

// Excerpt of the JDK class java.util.concurrent.CountedCompleter, shown for reference only.
// It is not compilable as shown: "......" stands for the members omitted from this listing.
public abstract class CountedCompleter<T> extends ForkJoinTask<T> {
    private static final long serialVersionUID = 5232453752276485070L;

    /** This task's completer, or null if none */
    final CountedCompleter<?> completer;
    /** The number of pending tasks until completion */
    volatile int pending;
    

    ......
}

  构造方法:

方法及参数 描述
CountedCompleter() 无父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer) 指定父级任务,初始化子任务数为0
CountedCompleter(CountedCompleter<?> completer, int initialPendingCount) 指定父级任务,初始化子任务数

  常用方法:

方法及参数 描述
abstract void compute() 执行的计算任务,必须实现
void onCompletion(CountedCompleter<?> caller) 调用tryComplete后,如果当前任务的所有子任务都完成(当前任务完成),则会调用该方法处理完成后的业务
boolean onExceptionalCompletion(Throwable ex, CountedCompleter<?> caller) 是否向父级任务传递异常,默认传递。即当前任务在执行compute()方法时抛出异常,或者显式的调用completeExceptionally(Throwable ex)抛出异常时,为true则父任务也异常完成,为false不会传递到父级任务
CountedCompleter<?> getCompleter() 返回当前任务的父级任务,没有返回null
int getPendingCount() 获取子任务数
void setPendingCount(int count) 设置子任务数(非原子性)
void addToPendingCount(int delta) 添加子任务数(原子性)
CountedCompleter<?> getRoot() 返回当前计算的根任务(沿父级任务链向上找到的最顶级任务);如果当前任务没有父级任务,则返回自身
void tryComplete() 尝试完成任务。如果当前任务的pending不为0,则通过CAS自旋将pending减1后返回;否则调用onCompletion(CountedCompleter<?> caller)方法执行完成后的逻辑,再以同样的方式尝试完成父级任务;如果已经没有父级任务,则将这个最顶级任务的状态设置为NORMAL(正常完成),并唤醒等待该任务的线程。
void propagateCompletion() 尝试完成任务。和tryComplete()的区别是:当所有子任务执行完成后,不会调用onCompletion(CountedCompleter<?> caller)方法。
void complete(T rawResult) 无论pending是否为0,先调用setRawResult(rawResult)设置结果,再调用onCompletion(CountedCompleter<?> caller)方法,然后将当前任务的状态设置为NORMAL(正常完成);如果存在父级任务,则对父级任务调用tryComplete()。
CountedCompleter<?> firstComplete() 如果当前任务的pending为0,则返回当前任务;否则通过CAS将pending减1,并返回null。通常与nextComplete()配合,在自定义的完成遍历循环中使用。
CountedCompleter<?> nextComplete() 如果当前任务存在父级任务,则调用父级任务的firstComplete()方法并返回其结果;否则将当前任务标记为正常完成(quietlyComplete()),并返回null。
void quietlyCompleteRoot() 相当于getRoot().quietlyComplete()
void helpComplete(int maxTasks) 当前任务阻塞等待任务执行时,尝试帮助当前任务执行其他未处理的任务。
void internalPropagateException(Throwable ex) 如果当前任务执行异常,并需要向父级任务传递时,沿父级任务链循环传递异常,直到最顶级任务(或某一级的onExceptionalCompletion返回false时停止)
boolean exec() 执行compute()方法,并返回false。ForkJoinPool调用该方法执行任务,如果返回true,ForkJoinPool会将当前任务标记为完成,并通知被当前任务阻塞的其他线程,这也是和RecursiveAction、RecursiveTask区别所在
T getRawResult() 获取结果。默认返回null,需要返回结果时重写
void setRawResult(T t) 设置结果。方法内部为空,需要重写
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Test;
import org.springframework.util.Assert;
import org.springframework.util.StopWatch;

import java.util.List;
import java.util.Random;
import java.util.concurrent.CountedCompleter;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * Unit tests demonstrating {@link CountedCompleter}.
 *
 * @author CL
 */
@Slf4j
public class CountedCompleterTest {

    /**
     * Sums 50,000 random numbers with a plain loop, a stream and a {@link CountedCompleter}, checks
     * that all results agree, and logs the time taken by each approach.
     */
    @Test
    public void testSum() {
        // Build the test data
        int total = 50_000;
        Random random = new Random();
        List<Integer> list = IntStream.range(1, total + 1).mapToObj(random::nextInt).collect(Collectors.toList());

        StopWatch stopWatch = new StopWatch();
        stopWatch.start("普通计算");

        int result1 = 0;
        for (Integer n : list) {
            result1 += n;
        }

        stopWatch.stop();

        stopWatch.start("Lambda计算");

        int result2 = list.stream().mapToInt(n -> n).sum();

        stopWatch.stop();

        stopWatch.start("ForkJoin计算");

        // The common pool is JVM-managed (shutdown() has no effect on it), so no cleanup is needed.
        Sum task = new Sum(null, list, new AtomicReference<>(0));
        int result3 = ForkJoinPool.commonPool().invoke(task);

        stopWatch.stop();

        Assert.isTrue(result1 == result2 && result1 == result3, "计算结果错误");

        for (StopWatch.TaskInfo taskInfo : stopWatch.getTaskInfo()) {
            log.info("{} 耗时:{} ms", taskInfo.getTaskName(), taskInfo.getTimeMillis());
        }
    }

    /**
     * Sums a list into a shared {@link AtomicReference} using {@link CountedCompleter}.
     *
     * <p>The task repeatedly splits {@link #THRESHOLD}-sized chunks off the tail of its list and
     * forks each one as a child, incrementing its own pending count first. It then sums the
     * remaining head chunk itself.
     *
     * <p>Every chunk, including the one summed by the parent, finishes with
     * {@link #propagateCompletion()}. The root completes once all of them have done so.
     */
    private static class Sum extends CountedCompleter<Integer> {

        private final static int THRESHOLD = 1000;
        private final List<Integer> list;
        private final AtomicReference<Integer> result;

        /**
         * @param parentTask the completer notified when this task completes, or {@code null} for the root
         * @param list       the elements this task sums
         * @param result     shared accumulator to which every chunk adds its partial sum
         */
        public Sum(CountedCompleter<Integer> parentTask, List<Integer> list, AtomicReference<Integer> result) {
            super(parentTask);
            this.list = list;
            this.result = result;
        }

        @Override
        public void compute() {
            // Split iteratively. Re-entering exec()/compute() recursively would use one stack
            // frame per chunk and could overflow the stack on very large lists.
            int end = list.size();
            while (end > THRESHOLD) {
                int split = end - THRESHOLD;
                // Increment pending before forking, so a child that finishes immediately cannot
                // find pending == 0 and complete this task early.
                addToPendingCount(1);
                new Sum(this, list.subList(split, end), result).fork();
                end = split;
            }

            result.getAndAccumulate(list.subList(0, end).stream().mapToInt(n -> n).sum(), Integer::sum);

            // propagateCompletion() completes the task tree without calling
            // onCompletion(CountedCompleter). Use tryComplete() instead when onCompletion should
            // run as each task finishes.
            propagateCompletion();
        }

        /** Returns the value accumulated so far; once the root has completed, this is the total. */
        @Override
        public Integer getRawResult() {
            return result.get();
        }

    }

}

  一组数求和测试结果:

22:54:17.123 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - 普通计算 耗时:9 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - Lambda计算 耗时:8 ms
22:54:17.131 [main] INFO com.c3stones.forkjoin.my.CountedCompleterTest - ForkJoin计算 耗时:9 ms

Guess you like

Origin blog.csdn.net/qq_48008521/article/details/130024018