ThreadPoolを使用したJava 8並列ストリーム

並列ストリームを実行すると、ForkJoinPool.commonPool()他のすべての並列ストリームによって共有される共通フォーク結合プール()で実行されます。
特定の数のスレッドで構築された別の専用スレッドプールでコードを並列実行したい場合があります。たとえば、使用する場合、myCollection.parallelStream()それを行うための便利な方法はありません。その目的に使用できる
小さな便利なユーティリティ(ThreadExecutorクラス)を書きました
次の例では、ThreadExecutorユーティリティを使用して、長い配列に計算された数値を入力する簡単な使用法を示します。各数値は、(共通プールではなく)フォーク結合プールのスレッドで計算されます。
スレッドプールの作成は、ユーティリティによって行われます。プール内のスレッド数を制御します(int parallelism)、プール内のスレッドの名前(スレッドダンプの調査時に役立ちます)、およびオプションでタイムアウト制限。
私はJUnit 5を使用してテストしました。これは、テストメソッドの時間を計る良い方法を提供します(TimingExtensionを参照)。

一次ソースコードは、GitHubのhttps://github.com/igalhaddad/thread-executorから入手でき
ます。

ThreadExecutor Utilityクラス

import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.common.util.concurrent.UncheckedTimeoutException;

import java.time.Duration;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class ThreadExecutor {
    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, Function<T, R> parallelStream) {
        return execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Function<T, R> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        Future<R> future = threadPool.submit(() -> parallelStream.apply(source));
        try {
            return timeout == 0 ? future.get() : future.get(timeout, unit);
        } catch (ExecutionException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (TimeoutException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (Throwable t) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        } finally {
            threadPool.shutdown();
        }
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, Consumer<T> parallelStream) {
        execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Consumer<T> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        CompletableFuture<Void> future = null;
        try {
            Runnable task = () -> parallelStream.accept(source);
            if (timeout == 0) {
                future = CompletableFuture.runAsync(task, threadPool);
                future.get();
                threadPool.shutdown();
            } else {
                threadPool.execute(task);
                threadPool.shutdown();
                if (!threadPool.awaitTermination(timeout, unit))
                    throw new TimeoutException("Timed out after: " + Duration.of(timeout, unit.toChronoUnit()));
            }
        } catch (TimeoutException e) {
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (ExecutionException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (Throwable t) {
            threadPool.shutdownNow();
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }
}

NamedForkJoinWorkerThreadFactoryクラス

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
    private AtomicInteger counter = new AtomicInteger(0);
    private final String name;
    private final boolean daemon;

    public NamedForkJoinWorkerThreadFactory(String name, boolean daemon) {
        this.name = name;
        this.daemon = daemon;
    }

    public NamedForkJoinWorkerThreadFactory(String name) {
        this(name, false);
    }

    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
        t.setName(name + counter.incrementAndGet());
        t.setDaemon(daemon);
        return t;
    }
}

ThreadExecutorTests JUnitクラス

import static org.junit.jupiter.api.Assertions.*;

import com.github.igalhaddad.threadexecutor.timing.TimingExtension;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@ExtendWith(TimingExtension.class)
@TestMethodOrder(OrderAnnotation.class)
@DisplayName("Test ThreadExecutor utility")
public class ThreadExecutorTests {
    private static final Logger logger = Logger.getLogger(ThreadExecutorTests.class.getName());
    private static final int SEQUENCE_LENGTH = 1000000;

    private static List<long[]> fibonacciSequences = new ArrayList<>();
    private long[] fibonacciSequence;

    @BeforeAll
    static void initAll() {
        logger.info(() -> "Number of available processors: " + Runtime.getRuntime().availableProcessors());
    }

    @BeforeEach
    void init() {
        this.fibonacciSequence = new long[SEQUENCE_LENGTH];
        fibonacciSequences.add(fibonacciSequence);
    }

    @AfterEach
    void tearDown() {
        int firstX = 10;
        logger.info(() -> "First " + firstX + " numbers: " + Arrays.stream(this.fibonacciSequence)
                .limit(firstX)
                .mapToObj(Long::toString)
                .collect(Collectors.joining(",", "[", ",...]")));
        int n = SEQUENCE_LENGTH - 1; // Last number
        assertFn(n);
        assertFn(n / 2);
        assertFn(n / 3);
        assertFn(n / 5);
        assertFn(n / 10);
        assertFn((n / 3) * 2);
        assertFn((n / 5) * 4);
    }

    private void assertFn(int n) {
        assertEquals(fibonacciSequence[n - 1] + fibonacciSequence[n - 2], fibonacciSequence[n]);
    }

    @AfterAll
    static void tearDownAll() {
        long[] fibonacciSequence = fibonacciSequences.iterator().next();
        for (int i = 1; i < fibonacciSequences.size(); i++) {
            assertArrayEquals(fibonacciSequence, fibonacciSequences.get(i));
        }
    }

    @Test
    @Order(1)
    @DisplayName("Calculate Fibonacci sequence sequentially")
    public void testSequential() {
        logger.info(() -> "Running sequentially. No parallelism");
        for (int i = 0; i < fibonacciSequence.length; i++) {
            fibonacciSequence[i] = Fibonacci.compute(i);
        }
    }

    @Test
    @Order(2)
    @DisplayName("Calculate Fibonacci sequence concurrently on all processors")
    public void testParallel1() {
        testParallel(Runtime.getRuntime().availableProcessors());
    }

    @Test
    @Order(3)
    @DisplayName("Calculate Fibonacci sequence concurrently on half of the processors")
    public void testParallel2() {
        testParallel(Math.max(1, Runtime.getRuntime().availableProcessors() / 2));
    }

    private void testParallel(int parallelism) {
        logger.info(() -> String.format("Running in parallel on %d processors", parallelism));
        ThreadExecutor.execute(parallelism, "FibonacciTask", fibonacciSequence,
                (long[] fibonacciSequence) -> Arrays.parallelSetAll(fibonacciSequence, Fibonacci::compute)
        );
    }

    static class Fibonacci {
        public static long compute(int n) {
            if (n <= 1)
                return n;
            long a = 0, b = 1;
            long sum = a + b; // for n == 2
            for (int i = 3; i <= n; i++) {
                a = sum; // using `a` for temporary storage
                sum += b;
                b = a;
            }
            return sum;
        }
    }
}

testParallel(int parallelism)メソッドに注意してください。このメソッドは、スレッドエグゼキューターユーティリティを使用して、提供されたスレッドの数で構成される個別の専用スレッドプールで並列ストリームを実行します。各スレッドには「FibonacciTask」という名前が付けられ、シリアル番号(「FibonacciTask3」など)と連結されます。一緒に。名前付きスレッドは、ForkJoinWorkerThread Factoryという名前のクラスからのものです。たとえば、testParallel2()ブレークポイントのテストメソッドフィボナッチ計算メソッドを中断すると、「FibonacciTask1-6」という名前の6つのスレッドが表示されました。これはそのうちの1つです。

"FibonacciTask3 @ 2715" prio = 5 tid = 0x22 nid = NA実行可能
java.lang.Thread.State:RUNNABLE

  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$Fibonacci.compute(ThreadExecutorTests.java:103)
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$$Lambda$366.1484420181.applyAsLong(Unknown Source:-1)
  at java.util.Arrays.lambda$parallelSetAll$2(Arrays.java:5408)
  at java.util.Arrays$$Lambda$367.864455139.accept(Unknown Source:-1)
  at java.util.stream.ForEachOps$ForEachOp$OfInt.accept(ForEachOps.java:204)
  at java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Streams.java:104)
  at java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:699)
  at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:484)
  at java.util.stream.ForEachOps$ForEachTask.compute(ForEachOps.java:290)
  at java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:746)
  at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:290)
  at java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1016)
  at java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1665)
  at java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1598)
  at java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:177)

testParallel(int parallelism)メソッドはArrays.parallelSetAllを実行します。実際、これは単純な並列ストリームであり、Javaソースコードに実装されています。

    public static void parallelSetAll(long[] array, IntToLongFunction generator) {
        Objects.requireNonNull(generator);
        IntStream.range(0, array.length).parallel().forEach(i -> { array[i] = generator.applyAsLong(i); });
    }

テストメソッドのタイミングを見てみましょう⏱:

試験結果
あなたが出力で見ることができるように:

  1. testSequential()テストメソッドは148622ミリ秒(並列処理なし)かかります。testParallel1()テストメソッドには16995ミリ秒(12プロセッサを並列に)かかります。testParallel2()テストメソッドは31152ミリ秒かかります(6つのプロセッサが並列)。

3つのテストメソッドはすべて同じタスクを実行します。つまり、1,000,000桁の長さのフィボナッチ数列を計算します。

から:https://dev.to//igalhaddad/java-8-parallel-stream-with-threadpool-32kd

公開元の記事0件 ・いい ね0件 訪問418

おすすめ

転載: blog.csdn.net/cunxiedian8614/article/details/105690002