Computing PI with ForkJoin

While teaching my son Python, I figured that estimating PI with the Monte Carlo method would interest him, so I wrote a version.
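The reasoning behind the method, assuming x and y are independent and uniformly distributed on [0, 1): the probability that a point lands inside the quarter circle equals the ratio of its area to the area of the unit square, so four times the observed hit fraction estimates PI.

```latex
P\left(x^2 + y^2 \le 1\right)
  = \frac{\text{area of quarter disc}}{\text{area of unit square}}
  = \frac{\pi \cdot 1^2 / 4}{1}
  = \frac{\pi}{4}
\qquad\Longrightarrow\qquad
\pi \approx 4 \cdot \frac{\text{count}}{\text{DARTS}}
```

Since sqrt(d) <= 1 exactly when d <= 1, comparing the squared distance with 1 is enough, which is why the code can leave the square root out.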

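    import time
    from random import random

    count = 0
    DARTS = 2**26        # number of random points ("darts") to throw
    print(DARTS)

    now = time.time()

    def distance_squared(x, y):
        # Squared distance from the origin; the square root is skipped
        # because sqrt(d) <= 1 exactly when d <= 1.
        return x ** 2 + y ** 2

    for _ in range(DARTS):
        x_rand, y_rand = random(), random()    # a point in the unit square
        if distance_squared(x_rand, y_rand) <= 1.0:
            count += 1                         # it landed inside the quarter circle

    pi = 4.0 * count / DARTS
    print("pi=" + str(pi))
    print('cost time=' + str(time.time() - now))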

On my machine, a single run took more than a minute:

67108864
pi=3.1413973569869995
cost time=69.3828194141388

That is far too slow, which made it a good opportunity to let him compare the performance of Java and Python.

The first Java version looked like this:

    // Members of a JUnit test class; assumes imports of java.util.SplittableRandom and org.junit.Test.
    private static final int DARTS = 1 << 26;   // 2^26 = 67108864 darts

    // Squared distance from the origin; the square root is unnecessary for the "<= 1" test.
    private static double distanceSquared(double x, double y) {
        return x * x + y * y;
    }

    @Test
    public void testPI() {
        System.out.println("DARTS = " + DARTS);

        long now = System.currentTimeMillis();

        int count = loop(DARTS, new SplittableRandom());

        double pi = 4.0 * count / DARTS;

        System.out.println("pi = " + pi);
        System.out.println("cost time = " + (System.currentTimeMillis() - now));
    }

    // Throws `times` darts at the unit square and counts those inside the quarter circle.
    private int loop(int times, SplittableRandom random) {
        int count = 0;
        for (int i = 0; i < times; i++) {
            double x = random.nextDouble();
            double y = random.nextDouble();

            if (distanceSquared(x, y) <= 1) {
                count++;
            }
        }
        return count;
    }

As expected, a single run is much faster: 545 ms against roughly 69 s for Python (the Java timing is in milliseconds).

DARTS = 67108864
pi = 3.141629636287689
cost time = 545
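Why do the estimates differ around the fourth decimal place? Each dart is an independent Bernoulli trial with success probability p = PI/4, so the standard error of the estimate shrinks only as 1/sqrt(N). A worked value for N = 2^26:

```latex
\sigma_{\hat{\pi}} = 4\sqrt{\frac{p(1-p)}{N}}, \quad p = \frac{\pi}{4},\ N = 2^{26}
\;\Longrightarrow\;
\sigma_{\hat{\pi}} \approx 4\sqrt{\frac{0.7854 \times 0.2146}{67108864}} \approx 2.0 \times 10^{-4}
```

Both the Python result (3.1413974) and the Java result (3.1416296) lie within about one standard error of PI, so the gap between them is sampling noise rather than a difference between the languages. Quadrupling the number of darts only halves this error.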

Estimating PI is CPU-bound work, so not using a ForkJoinPool to spread it across the cores would be a real waste.

    @Test
    public void forkPI() throws Exception {
        System.out.println("DARTS = " + DARTS);

        long now = System.currentTimeMillis();

        ForkJoinPool pool = ForkJoinPool.commonPool();
        // Parallelism of the common pool (by default, available processors minus one).
        System.out.println("pool = " + ForkJoinPool.getCommonPoolParallelism());

        SplittableRandom random = new SplittableRandom();

        ForkJoinTask<Integer> task = pool.submit(new CountTask(0, DARTS, random));

        Integer count = task.get();

        double pi = 4.0 * count / DARTS;
        System.out.println("pi = " + pi);
        System.out.println("time = " + (System.currentTimeMillis() - now));
    }

    // Counts hits for the half-open range of darts [start, end).
    // Also needs java.util.concurrent.ForkJoinPool, ForkJoinTask and RecursiveTask.
    private class CountTask extends RecursiveTask<Integer> {

        private static final int THRESHOLD = 256 * 32;  // ranges this small are computed directly
        private final int start;
        private final int end;
        private final SplittableRandom random;  // owned by this task only; SplittableRandom is not thread-safe

        CountTask(int start, int end, SplittableRandom random) {
            this.start = start;
            this.end = end;
            this.random = random;
        }

        @Override
        protected Integer compute() {
            if (end - start <= THRESHOLD) {
                return loop(end - start, random);
            }

            int middle = (start + end) >>> 1;
            // Each half gets its own generator split off from this task's generator.
            CountTask leftTask = new CountTask(start, middle, random.split());
            // The right half starts at middle, not middle + 1; otherwise one dart is lost at every split.
            CountTask rightTask = new CountTask(middle, end, random.split());

            leftTask.fork();                       // let another worker pick up the left half
            int rightCount = rightTask.compute();  // compute the right half in the current thread
            return leftTask.join() + rightCount;   // wait for the left half and combine
        }
    }
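For anyone who wants to run the idea outside a JUnit test, here is a minimal standalone sketch, not the author's code. The class name `ForkJoinPi`, the helpers `countHits` and `estimatePi`, and the fixed seed are hypothetical choices for illustration. It splits each range into [start, middle) and [middle, end) so every dart is counted, and it uses the usual fork/compute/join idiom: fork one half, compute the other in the current thread, then join.

```java
import java.util.SplittableRandom;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

public class ForkJoinPi {

    // Same leaf size as the CountTask above (256 * 32 darts).
    static final long THRESHOLD = 256 * 32;

    // Throws `darts` points into the unit square and counts those inside the quarter circle.
    static long countHits(long darts, SplittableRandom random) {
        long hits = 0;
        for (long i = 0; i < darts; i++) {
            double x = random.nextDouble();
            double y = random.nextDouble();
            if (x * x + y * y <= 1.0) {   // squared distance; no square root needed
                hits++;
            }
        }
        return hits;
    }

    // Counts hits for the half-open range of darts [start, end).
    static final class HitTask extends RecursiveTask<Long> {
        private final long start;
        private final long end;
        private final SplittableRandom random;   // used by this task only

        HitTask(long start, long end, SplittableRandom random) {
            this.start = start;
            this.end = end;
            this.random = random;
        }

        @Override
        protected Long compute() {
            if (end - start <= THRESHOLD) {
                return countHits(end - start, random);
            }
            long middle = (start + end) >>> 1;
            // Both child generators are split before forking, so the seed fixes every leaf's stream.
            HitTask left = new HitTask(start, middle, random.split());
            HitTask right = new HitTask(middle, end, random.split());
            left.fork();                       // make the left half available to other workers
            long rightHits = right.compute();  // work on the right half in this thread
            return left.join() + rightHits;    // wait for the left half and combine
        }
    }

    // Estimates PI from `darts` samples; a fixed seed makes the result reproducible.
    static double estimatePi(long darts, long seed) {
        long hits = ForkJoinPool.commonPool().invoke(new HitTask(0, darts, new SplittableRandom(seed)));
        return 4.0 * hits / darts;
    }

    public static void main(String[] args) {
        long darts = 1L << 26;
        long t0 = System.nanoTime();
        double pi = estimatePi(darts, 42L);
        long ms = (System.nanoTime() - t0) / 1_000_000;
        System.out.println("pi = " + pi + ", time = " + ms + " ms");
    }
}
```

Because each task splits its own SplittableRandom before forking, the random stream a leaf sees depends only on its position in the task tree, so the same seed yields the same estimate no matter how the workers schedule the tasks.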

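The speed improved: 303 ms versus 545 ms for the single-threaded version. Here `pool = 3` is the common pool's parallelism, which by default is the number of available processors minus one.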

DARTS = 67108864
pool = 3
pi = 3.1412885189056396
time = 303

I hope he gets it.

Reposted from blog.csdn.net/weixin_43364172/article/details/83960314