The Fork/Join pattern in JDK 7: easy parallel computing in the multi-core era

Introduction

As multi-core chips become mainstream, most software developers will inevitably need to learn parallel programming. At the same time, mainstream programming languages are adding more and more parallel features to their standard libraries or to the languages themselves. The JDK is at the forefront of this trend. In JDK Standard Edition 5, the concurrency framework written by Doug Lea became part of the standard library (JSR-166). Later, in JDK 6, further concurrency features, such as the concurrent collections framework, were merged into the standard library (JSR-166x). Today, although Java SE 7 has not yet been officially released, several new concurrency-related features have already appeared in JSR-166y:

  1. The Fork/Join pattern;

  2. TransferQueue, which extends BlockingQueue and lets a "producer" block until a consumer has received the element it hands over;

  3. ArrayTasks/ListTasks, classes used to execute array- and list-related tasks in parallel;

  4. IntTasks/LongTasks/DoubleTasks, utility classes for parallel processing of numeric arrays, providing sorting, searching, summing, minimum, and maximum functions.

Among these, support for the Fork/Join pattern is probably the new feature most widely used in developing parallel software. In JSR-166y, Doug Lea made extensive use of the Fork/Join pattern when implementing ArrayTasks/ListTasks/IntTasks/LongTasks/DoubleTasks. Note that, because JDK 7 has not been officially released, the features described in this article may differ from those in the released version.

The Fork/Join pattern has its own scope of application. If a problem can be decomposed into multiple subtasks whose results can be combined into the final answer, it is a good fit for the Fork/Join pattern. Figure 1 shows a schematic diagram of the pattern. Each Task in the upper part of the diagram depends on the Tasks below it: only when all the subtasks have completed can the caller obtain the result of Task 0.

Figure 1. Schematic diagram of Fork/Join pattern

The Fork/Join pattern can solve many kinds of parallel problems. With the Fork/Join framework provided by Doug Lea, software developers need to focus only on dividing tasks and combining intermediate results in order to exploit the performance of a parallel platform. Many other difficult parallelism-related problems, such as load balancing and synchronization, are handled uniformly by the framework. We thus obtain the benefits of parallelism while avoiding much of the difficulty and error-proneness of parallel programming.
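The divide-and-combine idea described above can be sketched in a few lines. The sketch below is only an illustration, and it assumes the java.util.concurrent classes that shipped with Java SE 7 rather than the pre-release jsr166y package used in the rest of this article. The class name SumTask, the helper parallelSum(), and the threshold of 1,000 elements are hypothetical choices made for this example.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical example: sums data[from..to) by splitting the range in half.
class SumTask extends RecursiveTask<Long> {
    private static final int THRESHOLD = 1_000; // illustrative cut-off, not tuned
    private final long[] data;
    private final int from; // inclusive
    private final int to;   // exclusive

    SumTask(long[] data, int from, int to) {
        this.data = data;
        this.from = from;
        this.to = to;
    }

    @Override
    protected Long compute() {
        // Small ranges are summed directly: this is the "leaf" of the task tree.
        if (to - from <= THRESHOLD) {
            long sum = 0;
            for (int i = from; i < to; i++) {
                sum += data[i];
            }
            return sum;
        }
        // Task division: split the range into two subtasks.
        int mid = (from + to) >>> 1;
        SumTask left = new SumTask(data, from, mid);
        SumTask right = new SumTask(data, mid, to);
        left.fork();                     // schedule the left half asynchronously
        long rightSum = right.compute(); // compute the right half in this thread
        // Combination of intermediate results.
        return left.join() + rightSum;
    }

    // Convenience wrapper: runs the task in a fresh pool and returns the total.
    static long parallelSum(long[] data) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(data, 0, data.length));
        } finally {
            pool.shutdown();
        }
    }
}
```

Calling SumTask.parallelSum(values) returns the same total as a plain loop. The only problem-specific parts are the split in compute() and the addition that merges the two partial sums; scheduling and load balancing are left to the pool.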

Using the Fork/Join pattern

Before trying the Fork/Join pattern, we need to download the JSR-166y source code from the Concurrency JSR-166 Interest Site maintained by Doug Lea, and install the latest version of JDK 6 (for download URLs, see Reference Resources). Using the Fork/Join pattern is straightforward. First, we write a ForkJoinTask that divides the work into subtasks and merges the intermediate results. Then we hand this ForkJoinTask to a ForkJoinPool, which executes it.

We usually do not extend ForkJoinTask directly, because it contains too many abstract methods. Instead, we choose among its subclasses according to the problem at hand. RecursiveAction is one such subclass. It represents the simplest kind of ForkJoinTask: it returns no value, so there are no intermediate results to combine after the subtasks finish. When extending RecursiveAction, we only need to override the protected void compute() method. Let's look at how to write a ForkJoinTask subclass for the quicksort algorithm:

Listing 1. Subclass of ForkJoinTask

import java.util.Arrays;

import jsr166y.forkjoin.RecursiveAction;

// Sorts array[lo..hi] (both bounds inclusive) with a parallel quicksort.
class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    // Ranges shorter than this are sorted sequentially instead of being split.
    private static final int THRESHOLD = 30;

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            // Sort the two parts as parallel subtasks and wait for both.
            coInvoke(new SortTask(array, lo, pivot - 1),
                     new SortTask(array, pivot + 1, hi));
        }
    }

    // Uses array[hi] as the pivot value and returns the pivot's final index.
    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        // Arrays.sort() takes an exclusive upper bound, hence hi + 1.
        Arrays.sort(array, lo, hi + 1);
    }
}

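In Listing 1, SortTask first uses the partition() method to split the array into two parts. Two subtasks are then created, each sorting one part. Once a subtask is small enough, splitting it further only hurts performance, so we use a THRESHOLD: below it, the range is sorted directly instead of being split into smaller tasks. Here we use coInvoke(), a method provided by RecursiveAction. It starts all the given tasks and returns once all of them have completed normally. If one of the tasks throws an exception, all the other tasks are cancelled. coInvoke() can also take an array of tasks as its argument.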
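Listing 1 relies on coInvoke() from the pre-release jsr166y.forkjoin package. As a hedged sketch, here is how the same task might look against the java.util.concurrent API that shipped with Java SE 7, where the static invokeAll() method inherited from ForkJoinTask plays the role of coInvoke(). The class name ReleasedApiSortTask and the helper sort() are hypothetical names chosen for this illustration; the threshold of 30 is simply carried over from Listing 1.

```java
import java.util.Arrays;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// Hypothetical counterpart of SortTask written against java.util.concurrent.
class ReleasedApiSortTask extends RecursiveAction {
    private static final int THRESHOLD = 30; // same cut-off as Listing 1
    private final long[] array;
    private final int lo; // inclusive
    private final int hi; // inclusive

    ReleasedApiSortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo < THRESHOLD) {
            // Small (or empty) range: sort it directly, upper bound exclusive.
            Arrays.sort(array, lo, hi + 1);
            return;
        }
        int pivot = partition(array, lo, hi);
        // invokeAll() runs both subtasks and returns when both have completed.
        invokeAll(new ReleasedApiSortTask(array, lo, pivot - 1),
                  new ReleasedApiSortTask(array, pivot + 1, hi));
    }

    // Same partition scheme as Listing 1: array[hi] is the pivot value.
    private static int partition(long[] a, int lo, int hi) {
        long x = a[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (a[j] <= x) {
                i++;
                long t = a[i]; a[i] = a[j]; a[j] = t;
            }
        }
        long t = a[i + 1]; a[i + 1] = a[hi]; a[hi] = t;
        return i + 1;
    }

    // Sorts the whole array in place using a fresh pool.
    static void sort(long[] array) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            pool.invoke(new ReleasedApiSortTask(array, 0, array.length - 1));
        } finally {
            pool.shutdown();
        }
    }
}
```

Here pool.invoke() blocks until the root task has finished, so the caller needs no separate awaitTermination() step before reading the array.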

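All that remains is to submit SortTask to a ForkJoinPool. By default, ForkJoinPool() creates a thread pool with as many threads as the CPU has available hardware threads. We submit SortTask to a newly created ForkJoinPool in a JUnit test method:

Listing 2. A newly created ForkJoinPool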

@Test
public void testSort() throws Exception {
    ForkJoinTask sort = new SortTask(array);
    ForkJoinPool fjpool = new ForkJoinPool();
    fjpool.submit(sort);
    fjpool.shutdown(); // no new tasks are accepted; the submitted sort keeps running

    // Wait up to 30 seconds for the sort to finish.
    fjpool.awaitTermination(30, TimeUnit.SECONDS);

    assertTrue(checkSorted(array));
}

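The code above uses the following ForkJoinPool methods:

  1. submit(): submits a ForkJoinTask object to the ForkJoinPool, which schedules it for execution right away.

  2. shutdown(): after this call, the ForkJoinPool accepts no new tasks, but tasks that have already been submitted continue to run. To stop all tasks immediately, try the shutdownNow() method.

  3. awaitTermination(): blocks the current thread until all tasks in the ForkJoinPool have finished or the given timeout expires.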
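The three calls listed above can be exercised together in a small sketch. It assumes the released java.util.concurrent.ForkJoinPool rather than the jsr166y pre-release, and the class name PoolLifecycleDemo and method runAll() are hypothetical. One detail worth showing is that awaitTermination() returns a boolean, so a timeout can be told apart from normal completion.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: submit tasks, shut the pool down, then wait for completion.
class PoolLifecycleDemo {
    // Submits taskCount trivial tasks and returns how many actually ran.
    static int runAll(int taskCount) {
        final AtomicInteger completed = new AtomicInteger();
        ForkJoinPool pool = new ForkJoinPool(); // default parallelism

        for (int i = 0; i < taskCount; i++) {
            // submit(): hand the task to the pool; it runs asynchronously.
            pool.submit(new RecursiveAction() {
                @Override
                protected void compute() {
                    completed.incrementAndGet();
                }
            });
        }

        // shutdown(): no new tasks are accepted, submitted ones still run.
        pool.shutdown();

        try {
            // awaitTermination(): false means the timeout elapsed first.
            if (!pool.awaitTermination(30, TimeUnit.SECONDS)) {
                throw new IllegalStateException("tasks still running after 30 seconds");
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException("interrupted while waiting", e);
        }
        return completed.get();
    }
}
```

Checking the return value of awaitTermination() is a design choice of this sketch: it keeps a slow task from being silently mistaken for a finished one.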

The complete code for the parallel quicksort is shown below:

Listing 3. Complete code for the parallel quicksort

package tests;

import static org.junit.Assert.*;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;

import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;

import org.junit.Before;
import org.junit.Test;

// Sorts array[lo..hi] (both bounds inclusive) with a parallel quicksort.
class SortTask extends RecursiveAction {
    final long[] array;
    final int lo;
    final int hi;
    // For demo only: 0 forces every non-empty range to be partitioned and traced.
    private static final int THRESHOLD = 0;

    public SortTask(long[] array) {
        this.array = array;
        this.lo = 0;
        this.hi = array.length - 1;
    }

    public SortTask(long[] array, int lo, int hi) {
        this.array = array;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo < THRESHOLD)
            sequentiallySort(array, lo, hi);
        else {
            int pivot = partition(array, lo, hi);
            // Trace each partition step so the progress of the sort is visible.
            System.out.println("\npivot = " + pivot + ", low = " + lo + ", high = " + hi);
            System.out.println("array" + Arrays.toString(array));
            coInvoke(new SortTask(array, lo, pivot - 1),
                     new SortTask(array, pivot + 1, hi));
        }
    }

    // Uses array[hi] as the pivot value and returns the pivot's final index.
    private int partition(long[] array, int lo, int hi) {
        long x = array[hi];
        int i = lo - 1;
        for (int j = lo; j < hi; j++) {
            if (array[j] <= x) {
                i++;
                swap(array, i, j);
            }
        }
        swap(array, i + 1, hi);
        return i + 1;
    }

    private void swap(long[] array, int i, int j) {
        if (i != j) {
            long temp = array[i];
            array[i] = array[j];
            array[j] = temp;
        }
    }

    private void sequentiallySort(long[] array, int lo, int hi) {
        // Arrays.sort() takes an exclusive upper bound, hence hi + 1.
        Arrays.sort(array, lo, hi + 1);
    }
}

public class TestForkJoinSimple {
    // For demo only; 8 elements, matching the sample output below.
    private static final int NARRAY = 8;
    long[] array = new long[NARRAY];
    Random rand = new Random();

    @Before
    public void setUp() {
        for (int i = 0; i < array.length; i++) {
            array[i] = rand.nextLong() % 100; // For demo only
        }
        System.out.println("Initial Array: " + Arrays.toString(array));
    }

    @Test
    public void testSort() throws Exception {
        ForkJoinTask sort = new SortTask(array);
        ForkJoinPool fjpool = new ForkJoinPool();
        fjpool.submit(sort);
        fjpool.shutdown();

        fjpool.awaitTermination(30, TimeUnit.SECONDS);

        assertTrue(checkSorted(array));
    }

    // Returns true if the array is in non-decreasing order.
    boolean checkSorted(long[] a) {
        for (int i = 0; i < a.length - 1; i++) {
            if (a[i] > (a[i + 1])) {
                return false;
            }
        }
        return true;
    }
}

Running the code above produces the following output:

Initial Array: [46, -12, 74, -67, 76, -13, -91, -96]

pivot = 0, low = 0, high = 7
array[-96, -12, 74, -67, 76, -13, -91, 46]

pivot = 5, low = 1, high = 7
array[-96, -12, -67, -13, -91, 46, 76, 74]

pivot = 1, low = 1, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 4, low = 2, high = 4
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 3, low = 2, high = 3
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 2, low = 2, high = 2
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 6, low = 6, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

pivot = 7, low = 7, high = 7
array[-96, -91, -67, -13, -12, 46, 74, 76]

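Advanced features of the Fork/Join pattern

Using RecursiveTask

Besides RecursiveAction, the Fork/Join framework provides other ForkJoinTask subclasses: RecursiveTask, which returns a value; AsyncAction and LinkedAsyncAction, which are completed explicitly by calling finish(); and CyclicAction, which can use a TaskBarrier to set a different termination condition for each task.

A subclass of RecursiveTask must likewise override the compute() method. The difference from RecursiveAction is that compute() returns a value, whose type is specified with a generic type parameter. Let's look at how to write a subclass of RecursiveTask.

Listing 4. A subclass of RecursiveTask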

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    // Looks up small results in a table instead of forking further subtasks.
    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    @Override
    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        // Start both subtasks asynchronously, then wait for and add their results.
        f1.fork();
        f2.fork();
        return f1.join() + f2.join();
    }
}

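In Listing 4, Fibonacci returns a value of type Integer. Its compute() method creates two subtasks, starts them, blocks until their results are available, and adds the results to obtain the final answer. As before, when a subtask is small enough its result is looked up in a table, which limits the overhead of splitting tasks too finely. Here we use the fork() and join() methods provided by RecursiveTask: fork() runs a subtask asynchronously, and join() blocks until its result is ready.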
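A common variation on the fork()/join() usage just described is to fork only one subtask and compute the other directly in the current thread, so the worker is not left idle while it waits. The sketch below shows that variation under the assumption that the released java.util.concurrent API is used instead of jsr166y. The class name Fib and the helper fib() are hypothetical; the lookup table and its indexing (fib(0) = fib(1) = 1) follow Listing 4.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Hypothetical variant of the Fibonacci task: fork one half, compute the other.
class Fib extends RecursiveTask<Integer> {
    // Same table as Listing 4: SMALL[k] is the k-th value, starting 1, 1, 2, ...
    private static final int[] SMALL = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
    private final int n;

    Fib(int n) {
        this.n = n;
    }

    @Override
    protected Integer compute() {
        if (n <= 10) {
            return SMALL[n]; // small enough: look the answer up
        }
        Fib f1 = new Fib(n - 1);
        f1.fork();                // f1 runs asynchronously, possibly on another worker
        Fib f2 = new Fib(n - 2);
        int r2 = f2.compute();    // f2 runs right here in the current worker
        return f1.join() + r2;    // wait for f1, then combine
    }

    // Convenience wrapper: evaluates the task in a fresh pool.
    static int fib(int n) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new Fib(n));
        } finally {
            pool.shutdown();
        }
    }
}
```

With the indexing of Listing 4, Fib.fib(14) evaluates to 610, the same value the article's own run reports for Fibonacci(14).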

All that remains is to submit Fibonacci to a ForkJoinPool, which we do in a JUnit test method:

Listing 5. Submitting Fibonacci to a ForkJoinPool

@Test
public void testFibonacci() throws InterruptedException, ExecutionException {
    ForkJoinTask<Integer> fjt = new Fibonacci(45);
    ForkJoinPool fjpool = new ForkJoinPool();
    Future<Integer> result = fjpool.submit(fjt);

    // The calling thread can do other work here; get() blocks until the result is ready.
    System.out.println(result.get());
}

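Using CyclicAction for cyclic tasks

CyclicAction is slightly more complicated to use. If a complex task requires several threads to cooperate, and the threads need to wait at certain points until all the others have arrived, CyclicAction and TaskBarrier make this easy. Figure 2 shows a typical scenario for CyclicAction and TaskBarrier.

Figure 2. Executing a multithreaded task with CyclicAction and TaskBarrier

A subclass of CyclicAction needs a TaskBarrier, which sets the termination condition for each task. The subclass must override the protected void compute() method to define the action performed at each step of the barrier. compute() is executed repeatedly until the barrier's isTerminated() method returns true. TaskBarrier behaves much like CyclicBarrier. Let's look at how to use a subclass of CyclicAction.

Listing 6. Using a subclass of CyclicAction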

class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        // The barrier decides when the cyclic actions stop: returning true terminates them.
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            // Each action's compute() is run once per cycle of the shared barrier.
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        // Start all actions, then wait for all of them to finish.
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

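In Listing 6, the CyclicAction[] array holds three tasks, each printing its cycle count and its index. The termination condition set in b.terminate() stops the tasks after 10 rounds of computation. All that remains is to submit ConcurrentPrint to a ForkJoinPool. The number of threads can be specified in the ForkJoinPool constructor; for example, ForkJoinPool(4) creates a pool with 4 threads. We run the ConcurrentPrint cyclic task in a JUnit test method:

Listing 7. Running the ConcurrentPrint cyclic task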

@Test
public void testBarrier () throws InterruptedException, ExecutionException {
    ForkJoinTask fjt = new ConcurrentPrint();
    ForkJoinPool fjpool = new ForkJoinPool(4);
    fjpool.submit(fjt);
    fjpool.shutdown();
}
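CyclicAction and TaskBarrier belong to the pre-release jsr166y.forkjoin package. In the java.util.concurrent package that shipped with Java SE 7, the closest counterpart appears to be Phaser, whose onAdvance(phase, registeredParties) hook plays a role similar to terminate(cycle, registeredParties) in Listing 6. The following is a minimal sketch under that assumption, using plain threads instead of CyclicAction. The class name PhaserRounds and the method run() are hypothetical.

```java
import java.util.concurrent.Phaser;
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical demo: several workers repeat a step in lock-step for a fixed number of rounds.
class PhaserRounds {
    // Runs `workers` threads for `rounds` rounds (rounds >= 1) and returns the
    // total number of steps executed by all workers.
    static int run(final int workers, final int rounds) {
        final AtomicInteger steps = new AtomicInteger();

        // All workers are registered up front via the constructor argument.
        final Phaser phaser = new Phaser(workers) {
            @Override
            protected boolean onAdvance(int phase, int registeredParties) {
                // Called once per round after every party has arrived.
                // Phases count from 0, so the last round is phase rounds - 1.
                return phase >= rounds - 1 || registeredParties == 0;
            }
        };

        Thread[] threads = new Thread[workers];
        for (int i = 0; i < workers; i++) {
            threads[i] = new Thread(new Runnable() {
                @Override
                public void run() {
                    do {
                        steps.incrementAndGet();        // the per-round work
                        phaser.arriveAndAwaitAdvance(); // wait for the other workers
                    } while (!phaser.isTerminated());
                }
            });
            threads[i].start();
        }

        for (Thread t : threads) {
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // preserve the interrupt status
                throw new IllegalStateException("interrupted while waiting", e);
            }
        }
        return steps.get();
    }
}
```

Because each worker performs exactly one step per round and the phaser terminates after the last round, PhaserRounds.run(3, 10) counts 30 steps in total.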

The complete code for the RecursiveTask and CyclicAction examples is shown below:

Listing 8. Complete code for the RecursiveTask and CyclicAction examples

package tests;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import jsr166y.forkjoin.CyclicAction;
import jsr166y.forkjoin.ForkJoinPool;
import jsr166y.forkjoin.ForkJoinTask;
import jsr166y.forkjoin.RecursiveAction;
import jsr166y.forkjoin.RecursiveTask;
import jsr166y.forkjoin.TaskBarrier;

import org.junit.Test;

class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    // Looks up small results in a table instead of forking further subtasks.
    private int compute(int small) {
        final int[] results = { 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89 };
        return results[small];
    }

    @Override
    public Integer compute() {
        if (n <= 10) {
            return compute(n);
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        Fibonacci f2 = new Fibonacci(n - 2);
        // Trace every fork so the shape of the task tree is visible.
        System.out.println("fork new thread for " + (n - 1));
        f1.fork();
        System.out.println("fork new thread for " + (n - 2));
        f2.fork();
        return f1.join() + f2.join();
    }
}

class ConcurrentPrint extends RecursiveAction {
    protected void compute() {
        // The barrier decides when the cyclic actions stop: returning true terminates them.
        TaskBarrier b = new TaskBarrier() {
            protected boolean terminate(int cycle, int registeredParties) {
                System.out.println("Cycle is " + cycle + ";"
                        + registeredParties + " parties");
                return cycle >= 10;
            }
        };
        int n = 3;
        CyclicAction[] actions = new CyclicAction[n];
        for (int i = 0; i < n; ++i) {
            final int index = i;
            // Each action's compute() is run once per cycle of the shared barrier.
            actions[i] = new CyclicAction(b) {
                protected void compute() {
                    System.out.println("I'm working " + getCycle() + " "
                            + index);
                    try {
                        Thread.sleep(500);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            };
        }
        // Start all actions, then wait for all of them to finish.
        for (int i = 0; i < n; ++i)
            actions[i].fork();
        for (int i = 0; i < n; ++i)
            actions[i].join();
    }
}

public class TestForkJoin {
    @Test
    public void testBarrier () throws InterruptedException, ExecutionException {
        System.out.println("\ntesting Task Barrier ...");
        ForkJoinTask fjt = new ConcurrentPrint();
        ForkJoinPool fjpool = new ForkJoinPool(4);
        fjpool.submit(fjt);
        fjpool.shutdown();
    }

    @Test
    public void testFibonacci () throws InterruptedException, ExecutionException {
        System.out.println("\ntesting Fibonacci ...");
        final int num = 14; //For demo only
        ForkJoinTask<Integer> fjt = new Fibonacci(num);
        ForkJoinPool fjpool = new ForkJoinPool();
        Future<Integer> result = fjpool.submit(fjt);

        // The calling thread can do other work here; get() blocks until the result is ready.
        System.out.println("Fibonacci(" + num + ") = " + result.get());
    }
}

Running the code above produces the following output:

testing Task Barrier ...
I'm working 0 2
I'm working 0 0
I'm working 0 1
Cycle is 0; 3 parties
I'm working 1 2
I'm working 1 0
I'm working 1 1
Cycle is 1; 3 parties
I'm working 2 0
I'm working 2 1
I'm working 2 2
Cycle is 2; 3 parties
I'm working 3 0
I'm working 3 2
I'm working 3 1
Cycle is 3; 3 parties
I'm working 4 2
I'm working 4 0
I'm working 4 1
Cycle is 4; 3 parties
I'm working 5 1
I'm working 5 0
I'm working 5 2
Cycle is 5; 3 parties
I'm working 6 0
I'm working 6 2
I'm working 6 1
Cycle is 6; 3 parties
I'm working 7 2
I'm working 7 0
I'm working 7 1
Cycle is 7; 3 parties
I'm working 8 1
I'm working 8 0
I'm working 8 2
Cycle is 8; 3 parties
I'm working 9 0
I'm working 9 2

testing Fibonacci ...
fork new thread for 13
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 12
fork new thread for 11
fork new thread for 10
fork new thread for 9
fork new thread for 10
fork new thread for 9
fork new thread for 11
fork new thread for 10
fork new thread for 10
fork new thread for 9
Fibonacci(14) = 610

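Conclusion

As the examples above show, the Fork/Join pattern lets software developers conveniently tap the computing power of multi-core platforms. Although it is not yet fully transparent to developers, the Fork/Join pattern greatly reduces the tedious work of writing concurrent programs. For applications that fit the pattern, developers no longer have to deal with parallelism-related chores such as synchronization and communication. Notoriously hard-to-debug errors such as deadlocks and data races are largely avoided, provided that subtasks do not share mutable state, and this raises the level at which problems can be thought about. You can think of the Fork/Join pattern as a parallel version of the Divide and Conquer strategy: focus only on how to divide tasks and combine intermediate results, and leave the rest to the Fork/Join framework.

Using the Fork/Join pattern in real work lets applications enjoy the free lunch that multi-core platforms offer.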

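About the work-stealing algorithm in Fork/Join

The fork-join framework reduces contention for work queues through a technique called work stealing. Each worker thread has its own work queue, implemented as a double-ended queue, or deque (Java 6 added several deque implementations to the class library, including ArrayDeque and LinkedBlockingDeque). When a task forks a new subtask, the subtask is pushed onto the head of the worker's own deque. When a task performs a join on another task that has not yet completed, the worker does not sleep until that task finishes (as Thread.join() would); instead, it takes another task from the head of its deque and executes it. When a thread's own task queue is empty, it tries to steal a task from the tail of another thread's deque.

Work stealing can be implemented with a standard queue, but a deque has two advantages over a standard queue: less contention and less stealing. Because only the owning worker thread accesses the head of its deque, there is never contention at the head. Because the tail of a deque is accessed only when some thread has run out of work, contention at the tail is also rare (building the deque implementation into the fork-join framework lets these access patterns reduce coordination costs even further). Compared with a traditional thread-pool approach, this reduction in contention greatly lowers synchronization costs. In addition, the last-in-first-out (LIFO) ordering implied by this approach means that the largest tasks sit at the tail of the deque. When another thread needs to steal a task, it gets one that can be decomposed into many smaller tasks, which reduces the need to steal again in the near future. Work stealing therefore achieves reasonable load balancing without central coordination and with minimal synchronization cost.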
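The access pattern described above can be made concrete with a toy model. The sketch below is not the framework's actual implementation: it simply wraps a LinkedBlockingDeque (one of the Java 6 deque classes mentioned above) to show that the owner works LIFO at the head while a thief takes the oldest, and typically largest, task from the tail. The class name WorkStealingSketch and its methods are hypothetical.

```java
import java.util.concurrent.LinkedBlockingDeque;

// Toy model of one worker's deque; task descriptions are plain strings.
class WorkStealingSketch {
    private final LinkedBlockingDeque<String> deque = new LinkedBlockingDeque<String>();

    // Owner: a newly forked task goes onto the head of the deque.
    void push(String task) {
        deque.addFirst(task);
    }

    // Owner: take the most recently pushed task (LIFO); null if the deque is empty.
    String pop() {
        return deque.pollFirst();
    }

    // Thief: take the oldest task from the tail; null if there is nothing to steal.
    String steal() {
        return deque.pollLast();
    }
}
```

If a worker pushes "sort 0..999", then "sort 0..499", then "sort 0..249", an idle thread calling steal() receives "sort 0..999", the task with the most remaining work to subdivide, while the owner's next pop() returns "sort 0..249".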

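Reference Resources

  • Read the article "The Free Lunch Is Over: A Fundamental Turn Toward Concurrency in Software" to learn why every serious software developer should understand parallel programming from now on.

  • Read Doug Lea's paper "A Java Fork/Join Framework" to learn about the implementation and performance of the Fork/Join pattern. http://gee.cs.oswego.edu/dl/papers/fj.pdf

  • Read the developerWorks article "Taming Tiger: Concurrent collections" to learn how to use the concurrent Collection library.

  • Read the developerWorks article "Java theory and practice: Introduction to nonblocking algorithms", which covers the major concurrency enhancements in JDK 5 and gives a general introduction to implementing nonblocking algorithms on the JDK 5 platform.

  • The book "Java Concurrency in Practice" presents a wide range of parallel programming techniques, anti-patterns, and workable solutions, and covers the new features in JDK 5 in detail.

  • Visit Doug Lea's JSR 166 site for the latest source code.

  • Download Java SE 6 from the Sun website.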

Origin blog.csdn.net/universsky2015/article/details/109349094