When a parallel stream is executed, it runs in the common fork-join pool (ForkJoinPool.commonPool()), which is shared with all other parallel streams.
Sometimes we want to run code in parallel on a separate, dedicated thread pool constructed with a specific number of threads. When using, for example, myCollection.parallelStream(), there is no convenient way to do that.
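As a minimal sketch of the underlying idea (not the ThreadExecutor code itself): when a parallel stream is started from inside a task that is already running on a ForkJoinPool worker, current JDK implementations schedule its subtasks on that same pool instead of the common pool. This is observed implementation behavior rather than a documented guarantee. The names DedicatedPoolDemo and workerNames are made up for this illustration.

```java
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class DedicatedPoolDemo {

    // Runs a parallel stream from inside a task submitted to the given pool and
    // returns the names of the threads that processed the elements.
    static Set<String> workerNames(ForkJoinPool pool, List<Integer> data) {
        Set<String> names = ConcurrentHashMap.newKeySet();
        pool.submit(() -> data.parallelStream()
                        .forEach(i -> names.add(Thread.currentThread().getName())))
            .join(); // wait for the stream to finish
        return names;
    }

    public static void main(String[] args) {
        List<Integer> data = IntStream.range(0, 10_000).boxed().collect(Collectors.toList());
        ForkJoinPool pool = new ForkJoinPool(4); // dedicated pool with 4 threads
        try {
            // Expect names of the dedicated pool's workers, not commonPool workers.
            System.out.println(workerNames(pool, data));
        } finally {
            pool.shutdown();
        }
    }
}
```

The caller has to create the pool, submit the task, wait for it, and shut the pool down, which is the boilerplate a helper utility can hide.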
I wrote a small, handy utility (the ThreadExecutor class) that can be used for that purpose.
In the following example, I demonstrate simple usage of the ThreadExecutor utility to fill a long array with computed numbers, where each number is computed on a thread of a fork-join pool (not the common pool).
The utility creates the thread pool. We control the number of threads in the pool (int parallelism), the name of the threads in the pool (useful when investigating a thread dump), and, optionally, a timeout limit.
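To make the optional timeout concrete, here is a rough, JDK-only sketch of waiting for a result with a time limit on a dedicated pool. It is an illustration under assumptions, not the utility's actual code: the names TimeoutSketch and runWithTimeout are hypothetical, and it throws IllegalStateException on timeout instead of Guava's UncheckedTimeoutException.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

class TimeoutSketch {

    // Hypothetical helper: runs `work` on a fresh pool of `parallelism` threads
    // and waits at most `timeout` for the result.
    static <R> R runWithTimeout(int parallelism, long timeout, TimeUnit unit, Callable<R> work) {
        ForkJoinPool pool = new ForkJoinPool(parallelism);
        Future<R> future = pool.submit(work);
        try {
            return future.get(timeout, unit);
        } catch (TimeoutException e) {
            future.cancel(true); // give up on the task
            throw new IllegalStateException("Timed out after " + timeout + " " + unit, e);
        } catch (ExecutionException e) {
            throw new RuntimeException(e.getCause()); // the task itself failed
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt flag
            throw new RuntimeException(e);
        } finally {
            pool.shutdownNow(); // the pool is single-use in this sketch
        }
    }

    public static void main(String[] args) {
        int answer = runWithTimeout(4, 5, TimeUnit.SECONDS, () -> 21 * 2);
        System.out.println(answer);
        try {
            runWithTimeout(4, 100, TimeUnit.MILLISECONDS, () -> {
                Thread.sleep(2_000); // deliberately slower than the timeout
                return 0;
            });
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```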
I tested it with JUnit 5, which provides a nice way to time test methods (see TimingExtension).
All the source code is available on GitHub at:
https://github.com/igalhaddad/thread-executor
ThreadExecutor utility class
import com.google.common.base.Throwables;
import com.google.common.util.concurrent.ExecutionError;
import com.google.common.util.concurrent.UncheckedExecutionException;
import com.google.common.util.concurrent.UncheckedTimeoutException;

import java.time.Duration;
import java.util.concurrent.*;
import java.util.function.Consumer;
import java.util.function.Function;

public class ThreadExecutor {

    // Runs the function on a dedicated pool and waits for the result without a time limit.
    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, Function<T, R> parallelStream) {
        return execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    // Runs the function on a dedicated pool. A timeout of 0 means "wait without a time limit".
    public static <T, R> R execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Function<T, R> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        Future<R> future = threadPool.submit(() -> parallelStream.apply(source));
        try {
            return timeout == 0 ? future.get() : future.get(timeout, unit);
        } catch (ExecutionException e) {
            // The task itself failed: rethrow its cause as an unchecked exception.
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (TimeoutException e) {
            future.cancel(true);
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (Throwable t) {
            future.cancel(true);
            threadPool.shutdownNow();
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        } finally {
            threadPool.shutdown();
        }
    }

    // Runs the consumer on a dedicated pool and waits for completion without a time limit.
    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, Consumer<T> parallelStream) {
        execute(parallelism, forkJoinWorkerThreadName, source, 0, null, parallelStream);
    }

    // Runs the consumer on a dedicated pool. A timeout of 0 means "wait without a time limit".
    public static <T> void execute(int parallelism, String forkJoinWorkerThreadName, T source, long timeout, TimeUnit unit, Consumer<T> parallelStream) {
        if (timeout < 0)
            throw new IllegalArgumentException("Invalid timeout " + timeout);
        // see java.util.concurrent.Executors.newWorkStealingPool(int parallelism)
        ExecutorService threadPool = new ForkJoinPool(parallelism, new NamedForkJoinWorkerThreadFactory(forkJoinWorkerThreadName), null, true);
        CompletableFuture<Void> future = null;
        try {
            Runnable task = () -> parallelStream.accept(source);
            if (timeout == 0) {
                future = CompletableFuture.runAsync(task, threadPool);
                future.get();
                threadPool.shutdown();
            } else {
                // No Future is kept in this branch, so an exception thrown by the task
                // is not propagated to the caller; only the timeout is reported.
                threadPool.execute(task);
                threadPool.shutdown();
                if (!threadPool.awaitTermination(timeout, unit))
                    throw new TimeoutException("Timed out after: " + Duration.of(timeout, unit.toChronoUnit()));
            }
        } catch (TimeoutException e) {
            threadPool.shutdownNow();
            throw new UncheckedTimeoutException(e);
        } catch (ExecutionException e) {
            // Only future.get() throws ExecutionException, so future is non-null here.
            future.cancel(true);
            threadPool.shutdownNow();
            Throwable cause = e.getCause();
            if (cause instanceof Error)
                throw new ExecutionError((Error) cause);
            throw new UncheckedExecutionException(cause);
        } catch (Throwable t) {
            threadPool.shutdownNow();
            Throwables.throwIfUnchecked(t);
            throw new RuntimeException(t);
        }
    }
}
NamedForkJoinWorkerThreadFactory class
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

public class NamedForkJoinWorkerThreadFactory implements ForkJoinPool.ForkJoinWorkerThreadFactory {
    private final AtomicInteger counter = new AtomicInteger(0);
    private final String name;
    private final boolean daemon;

    public NamedForkJoinWorkerThreadFactory(String name, boolean daemon) {
        this.name = name;
        this.daemon = daemon;
    }

    public NamedForkJoinWorkerThreadFactory(String name) {
        this(name, false);
    }

    @Override
    public ForkJoinWorkerThread newThread(ForkJoinPool pool) {
        ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(pool);
        t.setName(name + counter.incrementAndGet());
        t.setDaemon(daemon);
        return t;
    }
}
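To see the effect of such a factory in isolation, the following self-contained sketch applies the same naming scheme with a lambda and reports the name of the worker that runs a task. NamedWorkerDemo and firstWorkerName are hypothetical names used only for this illustration.

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.atomic.AtomicInteger;

class NamedWorkerDemo {

    // Builds a pool whose workers are named prefix1, prefix2, ... and returns
    // the name of the worker that ran a single submitted task.
    static String firstWorkerName(String prefix, int parallelism) {
        AtomicInteger counter = new AtomicInteger(0);
        ForkJoinPool.ForkJoinWorkerThreadFactory factory = p -> {
            ForkJoinWorkerThread t = ForkJoinPool.defaultForkJoinWorkerThreadFactory.newThread(p);
            t.setName(prefix + counter.incrementAndGet());
            return t;
        };
        ForkJoinPool pool = new ForkJoinPool(parallelism, factory, null, true);
        try {
            Callable<String> whoAmI = () -> Thread.currentThread().getName();
            return pool.submit(whoAmI).join();
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println(firstWorkerName("FibonacciTask", 4));
    }
}
```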
ThreadExecutorTests JUnit class
import static org.junit.jupiter.api.Assertions.*;

import com.github.igalhaddad.threadexecutor.timing.TimingExtension;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.MethodOrderer.OrderAnnotation;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Logger;
import java.util.stream.Collectors;

@ExtendWith(TimingExtension.class)
@TestMethodOrder(OrderAnnotation.class)
@DisplayName("Test ThreadExecutor utility")
public class ThreadExecutorTests {
    private static final Logger logger = Logger.getLogger(ThreadExecutorTests.class.getName());
    private static final int SEQUENCE_LENGTH = 1000000;
    private static List<long[]> fibonacciSequences = new ArrayList<>();
    private long[] fibonacciSequence;

    @BeforeAll
    static void initAll() {
        logger.info(() -> "Number of available processors: " + Runtime.getRuntime().availableProcessors());
    }

    @BeforeEach
    void init() {
        this.fibonacciSequence = new long[SEQUENCE_LENGTH];
        fibonacciSequences.add(fibonacciSequence);
    }

    @AfterEach
    void tearDown() {
        int firstX = 10;
        logger.info(() -> "First " + firstX + " numbers: " + Arrays.stream(this.fibonacciSequence)
                .limit(firstX)
                .mapToObj(Long::toString)
                .collect(Collectors.joining(",", "[", ",...]")));
        int n = SEQUENCE_LENGTH - 1; // Last number
        assertFn(n);
        assertFn(n / 2);
        assertFn(n / 3);
        assertFn(n / 5);
        assertFn(n / 10);
        assertFn((n / 3) * 2);
        assertFn((n / 5) * 4);
    }

    private void assertFn(int n) {
        assertEquals(fibonacciSequence[n - 1] + fibonacciSequence[n - 2], fibonacciSequence[n]);
    }

    @AfterAll
    static void tearDownAll() {
        long[] fibonacciSequence = fibonacciSequences.iterator().next();
        for (int i = 1; i < fibonacciSequences.size(); i++) {
            assertArrayEquals(fibonacciSequence, fibonacciSequences.get(i));
        }
    }

    @Test
    @Order(1)
    @DisplayName("Calculate Fibonacci sequence sequentially")
    public void testSequential() {
        logger.info(() -> "Running sequentially. No parallelism");
        for (int i = 0; i < fibonacciSequence.length; i++) {
            fibonacciSequence[i] = Fibonacci.compute(i);
        }
    }

    @Test
    @Order(2)
    @DisplayName("Calculate Fibonacci sequence concurrently on all processors")
    public void testParallel1() {
        testParallel(Runtime.getRuntime().availableProcessors());
    }

    @Test
    @Order(3)
    @DisplayName("Calculate Fibonacci sequence concurrently on half of the processors")
    public void testParallel2() {
        testParallel(Math.max(1, Runtime.getRuntime().availableProcessors() / 2));
    }

    private void testParallel(int parallelism) {
        logger.info(() -> String.format("Running in parallel on %d processors", parallelism));
        ThreadExecutor.execute(parallelism, "FibonacciTask", fibonacciSequence,
                (long[] fibonacciSequence) -> Arrays.parallelSetAll(fibonacciSequence, Fibonacci::compute)
        );
    }

    static class Fibonacci {
        public static long compute(int n) {
            if (n <= 1)
                return n;
            long a = 0, b = 1;
            long sum = a + b; // for n == 2
            for (int i = 3; i <= n; i++) {
                a = sum; // using `a` for temporary storage
                sum += b;
                b = a;
            }
            return sum;
        }
    }
}
Note the testParallel(int parallelism) method. It uses the ThreadExecutor utility to run a parallel stream on a separate, dedicated thread pool consisting of the given number of threads, where each thread is named "FibonacciTask" concatenated with a serial number (for example, "FibonacciTask3"). The named threads come from the NamedForkJoinWorkerThreadFactory class. For example, I paused the testParallel2() test method at a breakpoint in the Fibonacci.compute method and saw 6 threads named "FibonacciTask1" through "FibonacciTask6". This is one of them:
"FibonacciTask3@2715" prio=5 tid=0x22 nid=NA runnable
java.lang.Thread.State: RUNNABLE
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$Fibonacci.compute(ThreadExecutorTests.java:103)
  at com.github.igalhaddad.threadexecutor.util.ThreadExecutorTests$$Lambda$366.1484420181.applyAsLong(Unknown Source:-1)
  at java.util.Arrays.lambda$parallelSetAll$2(Arrays.java:5408)
  at java.util.Arrays$$Lambda$367.864455139.accept(Unknown Source:-1)
  at java.util.stream.ForEachOps$ForEachOp$OfInt.accept(ForEachOps.java:204)
  at java.util.stream.Streams$RangeIntSpliterator.forEachRemaining(Streams.java:104)
  at java.util.Spliterator$OfInt.forEachRemaining(Spliterator.java:699)
  at java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:484)
  at java.util.stream.ForEachOps$ForEachTask.compute(ForEachOps.java:290)
  at java.util.concurrent.CountedCompleter.exec(CountedCompleter.java:746)
  at java.util.concurrent.ForkJoinTask.doExec(ForkJoinTask.java:290)
  at java.util.concurrent.ForkJoinPool$WorkQueue.topLevelExec(ForkJoinPool.java:1016)
  at java.util.concurrent.ForkJoinPool.scan(ForkJoinPool.java:1665)
  at java.util.concurrent.ForkJoinPool.runWorker(ForkJoinPool.java:1598)
  at java.util.concurrent.ForkJoinWorkerThread.run(ForkJoinWorkerThread.java:177)
The testParallel(int parallelism) method executes Arrays.parallelSetAll. This is in fact just a simple parallel stream, as implemented in the Java source code:
public static void parallelSetAll(long[] array, IntToLongFunction generator) {
    Objects.requireNonNull(generator);
    IntStream.range(0, array.length).parallel().forEach(i -> { array[i] = generator.applyAsLong(i); });
}
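As a quick check of that equivalence, this small sketch fills one array with Arrays.parallelSetAll and another with an explicit parallel IntStream, then compares them. The class name ParallelSetAllDemo and the squaring generator are arbitrary choices for illustration.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

class ParallelSetAllDemo {

    // Fills the array through the library method.
    static long[] viaParallelSetAll(int n) {
        long[] a = new long[n];
        Arrays.parallelSetAll(a, i -> (long) i * i);
        return a;
    }

    // Fills the array through the parallel stream that the library method wraps.
    static long[] viaIntStream(int n) {
        long[] a = new long[n];
        IntStream.range(0, n).parallel().forEach(i -> a[i] = (long) i * i);
        return a;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.equals(viaParallelSetAll(1_000), viaIntStream(1_000)));
    }
}
```

Each index is written by exactly one task, so the parallel writes to the shared array do not conflict.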
Now let's look at the timing of the test methods ⏱:
- The testSequential() test method took 148622 milliseconds (no parallelism).
- The testParallel1() test method took 16995 milliseconds (12 processors in parallel).
- The testParallel2() test method took 31152 milliseconds (6 processors in parallel).
All three test methods perform the same task: computing a Fibonacci sequence of 1,000,000 numbers.
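To put those timings in perspective, the small sketch below derives the speedup (sequential time divided by parallel time) and the parallel efficiency (speedup divided by thread count) from the three measurements above. The class and method names are illustrative only.

```java
class SpeedupDemo {

    // Speedup = sequential time / parallel time.
    static double speedup(long sequentialMillis, long parallelMillis) {
        return (double) sequentialMillis / parallelMillis;
    }

    // Efficiency = speedup / number of threads (1.0 would be perfect scaling).
    static double efficiency(long sequentialMillis, long parallelMillis, int threads) {
        return speedup(sequentialMillis, parallelMillis) / threads;
    }

    public static void main(String[] args) {
        long sequential = 148_622; // testSequential()
        long parallel12 = 16_995;  // testParallel1(), 12 threads
        long parallel6 = 31_152;   // testParallel2(), 6 threads
        System.out.printf("12 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, parallel12), efficiency(sequential, parallel12, 12));
        System.out.printf(" 6 threads: speedup %.2f, efficiency %.2f%n",
                speedup(sequential, parallel6), efficiency(sequential, parallel6, 6));
    }
}
```

With these numbers the speedup is roughly 8.7x on 12 threads and 4.8x on 6 threads, so doubling the thread count did not quite double the throughput.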
From: https://dev.to//igalhaddad/java-8-parallel-stream-with-threadpool-32kd