Java 8 Parallel Stream with ThreadPool

A parallel stream runs in the common fork-join pool (ForkJoinPool.commonPool()), which is shared by all other parallel streams.
Sometimes we want to execute code in parallel on a separate, dedicated thread pool constructed with a specific number of threads. Calling myCollection.parallelStream(), for example, gives us no convenient way to do that.
I wrote a small, handy utility (the ThreadExecutor class) for that purpose.
In the following example, I demonstrate simple usage of the ThreadExecutor utility to fill a long array with calculated numbers. Each number is calculated in a thread of a fork-join pool (not the common pool).
The utility creates the thread pool. We control the number of threads in the pool (int parallelism), the name of the threads in the pool (useful when investigating thread dumps), and optionally a timeout limit.
I tested it with JUnit 5, which provides a nice way to time the test methods (see TimingExtension).
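Before looking at the utility itself, here is a minimal sketch of the underlying idea: a parallel stream started from inside a task that runs in a dedicated ForkJoinPool has its subtasks executed by that pool instead of the common pool. This relies on how the stream implementation forks its work, which is observed behavior rather than a documented guarantee. The class and method names below (DedicatedPoolSketch, elementsOutsidePool) are hypothetical and are not part of the article's repository.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.stream.IntStream;

public class DedicatedPoolSketch {

    // Counts the stream elements that were processed outside the dedicated pool.
    // ForkJoinTask.getPool() returns the pool hosting the current worker thread,
    // or null when the calling thread is not a fork-join worker.
    static long elementsOutsidePool(int parallelism) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        try {
            // The parallel stream is started from inside a task running in `pool`,
            // so its subtasks are forked into `pool` rather than the common pool.
            return pool.submit(() ->
                    IntStream.range(0, 100_000)
                            .parallel()
                            .filter(i -> ForkJoinTask.getPool() != pool)
                            .count()
            ).join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("Elements processed outside the dedicated pool: "
                + elementsOutsidePool(4));
    }
}
```

The ThreadExecutor utility wraps the same submit-and-wait pattern and adds thread naming, an optional timeout, and exception translation.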

All source code is available on GitHub at:
https://github.com/igalhaddad/thread-executor

ThreadExecutor Utility class

import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.common.util.concurrent.UncheckedTimeoutException;

import java.time.Duration;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class ThreadExecutor {
    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, Function<T, R> parallelStream) {
        return execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Function<T, R> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        Future<R> future = threadPool.submit(() -> parallelStream.apply(source));
        try {
            return timeout == 0 ? future.get() : future.get(timeout, unit);
        } catch (ExecutionException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (TimeoutException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
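        } catch (Throwable t) {
            future.cancel(true);
            threadPool.shutdownNow();
            if (t instanceof InterruptedException)
                Thread.currentThread().interrupt(); // restore the interrupt flag for the caller
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);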
        } finally {
            threadPool.shutdown();
        }
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, Consumer<T> parallelStream) {
        execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Consumer<T> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        CompletableFuture<Void> future = null;
        try {
            Runnable task = () -> parallelStream.accept(source);
            if (timeout == 0) {
                future = CompletableFuture.runAsync(task, threadPool);
                future.get();
                threadPool.shutdown();
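            } else {
                // Note: execute() returns no Future, so an exception thrown by the task
                // in this branch is not rethrown to the caller of this method.
                threadPool.execute(task);
                threadPool.shutdown();
                if (!threadPool.awaitTermination(timeout, unit))
                    throw new TimeoutException("Timed out after: " + Duration.of(timeout, unit.toChronoUnit()));
            }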
        } catch (TimeoutException e) {
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (ExecutionException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
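        } catch (Throwable t) {
            threadPool.shutdownNow();
            if (t instanceof InterruptedException)
                Thread.currentThread().interrupt(); // restore the interrupt flag for the caller
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);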
        }
    }
}
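
The timeout handling above can be illustrated without Guava. The following is a rough sketch under stated assumptions, not the utility's actual code: TimeoutSketch and callWithTimeout are hypothetical names, and the sketch reports a timeout as an empty Optional instead of throwing UncheckedTimeoutException. Note that a task returning null would also produce an empty Optional here.

```java
import java.util.Optional;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

public class TimeoutSketch {

    // Runs the task on a dedicated pool and waits at most `timeout` for its result.
    // Returns Optional.empty() when the wait times out (hypothetical convention).
    static <R> Optional<R> callWithTimeout(int parallelism, long timeout, TimeUnit unit, Callable<R> task) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        Future<R> future = pool.submit(task);
        try {
            return Optional.ofNullable(future.get(timeout, unit));
        } catch (TimeoutException e) {
            future.cancel(true);
            return Optional.empty();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            // Stop the workers whether the task finished, failed, or timed out.
            pool.shutdownNow();
        }
    }

    public static void main(String[] args) {
        System.out.println(callWithTimeout(2, 5, TimeUnit.SECONDS, () -> 42));
        System.out.println(callWithTimeout(2, 50, TimeUnit.MILLISECONDS, () -> {
            Thread.sleep(3_000); // simulates a task that is too slow
            return 1;
        }));
    }
}
```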

NamedForkJoinWorkerThreadFactory class

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
    private final AtomicInteger counter = new AtomicInteger(0);
    private final String name;
    private final boolean daemon;

    public NamedForkJoinWorkerThreadFactory(String name, boolean daemon) {
        this.name = name;
        this.daemon = daemon;
    }

    public NamedForkJoinWorkerThreadFactory(String name) {
        this(name, false);
    }

    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
        t.setName(name + counter.incrementAndGet());
        t.setDaemon(daemon);
        return t;
    }
}

ThreadExecutorTests JUnit class

import static org.junit.jupiter.api.Assertions.*;

import com.github.igalhaddad.threadexecutor.timing.TimingExtension;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@ExtendWith(TimingExtension.class)
@TestMethodOrder(OrderAnnotation.class)
@DisplayName("Test ThreadExecutor utility")
public class ThreadExecutorTests {
    private static final Logger logger = Logger.getLogger(ThreadExecutorTests.class.getName());
    private static final int SEQUENCE_LENGTH = 1000000;

    private static List<long[]> fibonacciSequences = new ArrayList<>();
    private long[] fibonacciSequence;

    @BeforeAll
    static void initAll() {
        logger.info(() -> "Number of available processors: " + Runtime.getRuntime().availableProcessors());
    }

    @BeforeEach
    void init() {
        this.fibonacciSequence = new long[SEQUENCE_LENGTH];
        fibonacciSequences.add(fibonacciSequence);
    }

    @AfterEach
    void tearDown() {
        int firstX = 10;
        logger.info(() -> "First " + firstX + " numbers: " + Arrays.stream(this.fibonacciSequence)
                .limit(firstX)
                .mapToObj(Long::toString)
                .collect(Collectors.joining(",", "[", ",...]")));
        int n = SEQUENCE_LENGTH - 1; // Last number
        assertFn(n);
        assertFn(n / 2);
        assertFn(n / 3);
        assertFn(n / 5);
        assertFn(n / 10);
        assertFn((n / 3) * 2);
        assertFn((n / 5) * 4);
    }

    private void assertFn(int n) {
        assertEquals(fibonacciSequence[n - 1] + fibonacciSequence[n - 2], fibonacciSequence[n]);
    }

    @AfterAll
    static void tearDownAll() {
        long[] fibonacciSequence = fibonacciSequences.iterator().next();
        for (int i = 1; i < fibonacciSequences.size(); i++) {
            assertArrayEquals(fibonacciSequence, fibonacciSequences.get(i));
        }
    }

    @Test
    @Order(1)
    @DisplayName("Calculate Fibonacci sequence sequentially")
    public void testSequential() {
        logger.info(() -> "Running sequentially. No parallelism");
        for (int i = 0; i < fibonacciSequence.length; i++) {
            fibonacciSequence[i] = Fibonacci.compute(i);
        }
    }

    @Test
    @Order(2)
    @DisplayName("Calculate Fibonacci sequence concurrently on all processors")
    public void testParallel1() {
        testParallel(Runtime.getRuntime().availableProcessors());
    }

    @Test
    @Order(3)
    @DisplayName("Calculate Fibonacci sequence concurrently on half of the processors")
    public void testParallel2() {
        testParallel(Math.max(1, Runtime.getRuntime().availableProcessors() / 2));
    }

    private void testParallel(int parallelism) {
        logger.info(() -> String.format("Running in parallel on %d processors", parallelism));
        ThreadExecutor.execute(parallelism, "FibonacciTask", fibonacciSequence,
                (long[] fibonacciSequence) -> Arrays.parallelSetAll(fibonacciSequence, Fibonacci::compute)
        );
    }

    static class Fibonacci {
        public static long compute(int n) {
            if (n <= 1)
                return n;
            long a = 0, b = 1;
            long sum = a + b; // for n == 2
            for (int i = 3; i <= n; i++) {
                a = sum; // using `a` for temporary storage
                sum += b;
                b = a;
            }
            return sum;
        }
    }
}

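Notice the testParallel(int parallelism) method. It uses the ThreadExecutor utility to execute a parallel stream on a separate, dedicated thread pool consisting of the given number of threads, where each thread is named "FibonacciTask" concatenated with a serial number (for example, "FibonacciTask3"). The named threads come from the NamedForkJoinWorkerThreadFactory class. For example, I paused the testParallel2() test method at a breakpoint in the Fibonacci.compute method, and I saw 6 threads named "FibonacciTask1-6". Here is one of them: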

"FibonacciTask3@2715" prio=5 tid=0x22 nid=NA runnable
java.lang.Thread.State: RUNNABLE

  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$Fibonacci.compute(ThreadExecutorTests.java:103)
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$$Lambda$366.1484420181.applyAsLong(Unknown Source:-1)
  at java.util.Arrays.lambda$parallelSetAll$2(Arrays.java:5408)
  at java.util.Arrays$$Lambda$367.864455139.accept(Unknown Source:-1)
  at java.util.stream.ForEachOps$ForEachOp$OfInt.accept(ForEachOps.java:204)
  at java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Streams.java:104)
  at java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:699)
  at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:484)
  at java.util.stream.ForEachOps$ForEachTask.compute(ForEachOps.java:290)
  at java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:746)
  at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:290)
  at java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1016)
  at java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1665)
  at java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1598)
  at java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:177)
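
The naming scheme can be checked without a debugger by collecting the names of the threads that actually process the stream elements. This is a small sketch, assuming a lambda-based factory that mirrors NamedForkJoinWorkerThreadFactory; the names NamedWorkersSketch and workerNames are hypothetical.

```java
import java.util.Set;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class NamedWorkersSketch {

    // Runs a parallel stream on a dedicated pool and returns the distinct names
    // of the worker threads that processed its elements.
    static Set<String> workerNames(String prefix, int parallelism) {
        AtomicInteger counter = new AtomicInteger();
        // Same idea as NamedForkJoinWorkerThreadFactory: create a default worker, then rename it.
        ForkJoinPool.ForkJoinWorkerThreadFactory factory = p -> {
            ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(p);
            t.setName(prefix + counter.incrementAndGet());
            return t;
        };
        ForkJoinPool pool = new ForkJoinPool(parallelism, factory, null, true);
        try {
            return pool.submit(() ->
                    IntStream.range(0, 100_000)
                            .parallel()
                            .mapToObj(i -> Thread.currentThread().getName())
                            .collect(Collectors.toSet())
            ).join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Which workers show up depends on scheduling, so the exact set varies between runs.
        System.out.println(workerNames("FibonacciTask", 6));
    }
}
```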

The testParallel(int parallelism) method executes Arrays.parallelSetAll, which is in fact just a simple parallel stream, as implemented in the Java source code:

    public static void parallelSetAll(long[] array, IntToLongFunction generator) {
        Objects.requireNonNull(generator);
        IntStream.range(0, array.length).parallel().forEach(i -> { array[i] = generator.applyAsLong(i); });
    }
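
To make the equivalence concrete, the sketch below fills one array with Arrays.parallelSetAll and another with the explicit parallel IntStream form, then compares them. The class name, method names, and the squaring generator are illustrative assumptions, not code from the article.

```java
import java.util.Arrays;
import java.util.function.IntToLongFunction;
import java.util.stream.IntStream;

public class ParallelSetAllSketch {

    // Fills a new array using the library method.
    static long[] viaParallelSetAll(int length, IntToLongFunction generator) {
        long[] array = new long[length];
        Arrays.parallelSetAll(array, generator);
        return array;
    }

    // Fills a new array using the equivalent parallel stream over the indices.
    static long[] viaParallelStream(int length, IntToLongFunction generator) {
        long[] array = new long[length];
        IntStream.range(0, length).parallel().forEach(i -> array[i] = generator.applyAsLong(i));
        return array;
    }

    // True when both approaches produce identical arrays for a sample generator.
    static boolean sameResult(int length) {
        IntToLongFunction square = i -> (long) i * i;
        return Arrays.equals(viaParallelSetAll(length, square), viaParallelStream(length, square));
    }

    public static void main(String[] args) {
        System.out.println("Same result: " + sameResult(1_000));
    }
}
```

Each index is written by exactly one task, so no synchronization on the array is needed in either form.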

Now let's look at the timing of the test methods ⏱:

Test Results
As you can see in the output:

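  1. The testSequential() test method took 148622 ms (no parallelism).
  2. The testParallel1() test method took 16995 ms (12 processors in parallel).
  3. The testParallel2() test method took 31152 ms (6 processors in parallel).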

All three test methods perform the same task: computing a Fibonacci sequence of 1,000,000 numbers.
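
As a quick worked example, the reported timings give a speedup of roughly 8.7x on 12 threads and roughly 4.8x on 6 threads over the sequential run, which corresponds to a parallel efficiency of about 0.73 and 0.80 respectively. The helper below is a hypothetical sketch for that arithmetic; only the three timing values come from the results above.

```java
public class SpeedupSketch {

    // Speedup = sequential time divided by parallel time.
    static double speedup(long sequentialMillis, long parallelMillis) {
        return (double) sequentialMillis / parallelMillis;
    }

    // Efficiency = speedup divided by the number of threads used.
    static double efficiency(long sequentialMillis, long parallelMillis, int threads) {
        return speedup(sequentialMillis, parallelMillis) / threads;
    }

    public static void main(String[] args) {
        long sequential = 148_622; // testSequential()
        long parallel12 = 16_995;  // testParallel1(), 12 processors
        long parallel6 = 31_152;   // testParallel2(), 6 processors
        System.out.printf("12 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, parallel12), efficiency(sequential, parallel12, 12));
        System.out.printf(" 6 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, parallel6), efficiency(sequential, parallel6, 6));
    }
}
```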

from: https://dev.to//igalhaddad/java-8-parallel-stream-with-threadpool-32kd
