While teaching my son Python, I figured that estimating π with the Monte Carlo method would catch his interest, so I wrote a version.
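Here is why the method works. A dart (x, y) drawn uniformly from the unit square [0, 1) × [0, 1) lands inside the quarter circle of radius 1 with probability equal to the quarter circle's area. So four times the fraction of hits estimates π:

$$
P\left(x^2 + y^2 \le 1\right) = \frac{\pi \cdot 1^2 / 4}{1 \cdot 1} = \frac{\pi}{4}
\quad\Longrightarrow\quad
\pi \approx 4 \cdot \frac{\mathrm{count}}{\mathrm{DARTS}}
$$

Testing x² + y² ≤ 1 is equivalent to testing √(x² + y²) ≤ 1 because both sides are non-negative. That is why the code can skip the square root.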
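import time
from random import random

DARTS = 2 ** 26
print(DARTS)

def squared_distance(x, y):
    # The square root (** 0.5) is unnecessary: x*x + y*y <= 1 exactly when
    # the distance itself is <= 1, since both are non-negative.
    return x ** 2 + y ** 2

now = time.time()
count = 0
for i in range(DARTS):
    # One dart: a uniformly random point in the unit square [0, 1) x [0, 1).
    x_rand, y_rand = random(), random()
    if squared_distance(x_rand, y_rand) <= 1.0:
        count += 1  # the dart landed inside the quarter circle

# Python 3 "/" is true division, so no float() conversion is needed.
pi = 4 * count / DARTS
print("pi=" + str(pi))
print("cost time=" + str(time.time() - now))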
On my machine, a single run took more than a minute:
67108864
pi=3.1413973569869995
cost time=69.3828194141388
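The sketch below is a minimal one, assuming Python 3. It moves the same loop into a function. In CPython, local variables are looked up faster than module-level globals, so this form usually runs somewhat faster than the top-level script. It is still a pure-Python loop, though, so it will not close the gap with Java. The name `estimate_pi` and the `seed` parameter are illustrative and do not come from the original script. The seed only makes runs reproducible.

```python
import math
import random
import time


# Hypothetical helper for illustration; not part of the original script.
def estimate_pi(darts, seed=None):
    """Throw `darts` random points into the unit square and return
    4 * (fraction that lands inside the quarter circle x^2 + y^2 <= 1)."""
    rng = random.Random(seed)  # private generator, reproducible when seeded
    rand = rng.random          # bind the method to a local name once
    count = 0
    for _ in range(darts):
        x, y = rand(), rand()
        if x * x + y * y <= 1.0:  # squared distance, so no sqrt is needed
            count += 1
    return 4 * count / darts


if __name__ == "__main__":
    start = time.perf_counter()
    pi = estimate_pi(2 ** 20, seed=42)  # fewer darts than 2**26, for a quick try
    print(f"pi={pi}")
    print(f"error={abs(pi - math.pi)}")
    print(f"cost time={time.perf_counter() - start}")
```

Passing the same seed twice gives the same estimate. That makes the function easy to test, and lets you compare the code's speed without the random numbers changing between runs.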
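More than a minute is far too slow, which makes this a good chance for him to compare the performance of Java and Python.
The first Java version looked like this: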
// Members of a JUnit test class (imports assumed: org.junit.Test, java.util.SplittableRandom).
private static final int DARTS = 1 << 26; // 2^26 = 67108864 darts

// Squared distance from the origin. Comparing it with 1 gives the same result
// as comparing the distance itself, so Math.pow(..., 0.5) can be skipped.
private static double squaredDistance(double x, double y) {
    return x * x + y * y;
}

@Test
public void testPI() {
    System.out.println("DARTS = " + DARTS);
    long now = System.currentTimeMillis();
    int count = loop(DARTS, new SplittableRandom());
    double pi = 4.0 * count / DARTS;
    System.out.println("pi = " + pi);
    System.out.println("cost time = " + (System.currentTimeMillis() - now));
}

// Throws `times` darts into the unit square and counts those inside the quarter circle.
private int loop(int times, SplittableRandom random) {
    int count = 0;
    for (int i = 0; i < times; i++) {
        double x = random.nextDouble();
        double y = random.nextDouble();
        if (squaredDistance(x, y) <= 1) {
            count++;
        }
    }
    return count;
}
As expected, a single run was much faster (the Java timing is in milliseconds):
DARTS = 67108864
pi = 3.141629636287689
cost time = 545
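The printed values disagree in the fourth decimal place for a statistical reason. Each dart is an independent trial that hits with probability p = π/4. The standard error of the estimate 4 · count / N is therefore

$$
\sigma = 4\sqrt{\frac{p(1-p)}{N}}, \quad p = \frac{\pi}{4},\; N = 2^{26}
\;\Longrightarrow\;
\sigma \approx 4\sqrt{\frac{0.1685}{67108864}} \approx 2.0 \times 10^{-4}
$$

Both runs so far (3.14140 and 3.14163, against π ≈ 3.14159) fall within about one σ of the true value. Since σ shrinks like 1/√N, each additional correct digit costs roughly 100 times as many darts.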
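Estimating π is CPU-bound work, and every dart is independent, so the job splits cleanly into chunks. Leaving the other cores idle instead of using a ForkJoinPool is a real waste.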
// Additional imports assumed: java.util.concurrent.ForkJoinPool, ForkJoinTask, RecursiveTask.
@Test
public void forkPI() throws Exception {
    System.out.println("DARTS = " + DARTS);
    long now = System.currentTimeMillis();
    ForkJoinPool pool = ForkJoinPool.commonPool();
    System.out.println("pool = " + ForkJoinPool.getCommonPoolParallelism());
    SplittableRandom random = new SplittableRandom();
    ForkJoinTask<Integer> task = pool.submit(new CountTask(0, DARTS, random));
    int count = task.get();
    double pi = 4.0 * count / DARTS;
    System.out.println("pi = " + pi);
    System.out.println("time = " + (System.currentTimeMillis() - now));
}

private class CountTask extends RecursiveTask<Integer> {
    // Ranges at or below this many darts are computed directly instead of split.
    private static final int THRESHOLD = 256 * 32;
    private final int start;
    private final int end;
    private final SplittableRandom random;

    CountTask(int start, int end, SplittableRandom random) {
        this.start = start;
        this.end = end;
        this.random = random;
    }

    @Override
    protected Integer compute() {
        if (end - start <= THRESHOLD) {
            return loop(end - start, random);
        }
        int middle = (start + end) / 2;
        // [start, middle) and [middle, end) together cover every dart.
        // Each child gets its own generator, because SplittableRandom is not thread-safe.
        CountTask leftTask = new CountTask(start, middle, random.split());
        CountTask rightTask = new CountTask(middle, end, random.split());
        leftTask.fork();                       // schedule the left half on the pool
        int rightCount = rightTask.compute();  // do the right half in this thread
        return leftTask.join() + rightCount;
    }
}
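You can check that a recursive split like the one in `CountTask` hands every dart to exactly one leaf without throwing any darts at all. The minimal Python sketch below simulates the split. The function name `covered` and its `right_offset` parameter are hypothetical and exist only for illustration. With `right_offset=0`, the children are [start, middle) and [middle, end). With `right_offset=1`, the right child starts at middle + 1, which skips one dart at every split.

```python
# Hypothetical checker: simulates a CountTask-style split and returns how many
# darts the leaves would actually throw.
def covered(start, end, threshold, right_offset=0):
    if end - start <= threshold:
        return end - start  # a leaf throws end - start darts
    middle = (start + end) // 2
    return (covered(start, middle, threshold, right_offset)
            + covered(middle + right_offset, end, threshold, right_offset))


DARTS = 2 ** 26
THRESHOLD = 256 * 32
print(covered(0, DARTS, THRESHOLD))     # 67108864: every dart is thrown
print(covered(0, DARTS, THRESHOLD, 1))  # fewer than DARTS
```

With these parameters the split tree has 2^13 - 1 = 8191 internal nodes. The `middle + 1` variant therefore throws 8191 fewer darts than DARTS. Because the count is still divided by DARTS, 4 * count / DARTS comes out slightly low.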
The fork/join version was faster still: 303 ms, versus 545 ms for the single-threaded loop.
DARTS = 67108864
pool = 3
pi = 3.1412885189056396
time = 303
I hope he gets it.