蒙特卡罗方法采样算法
蒙特卡罗方法(Monte Carlo Simulation
)是一种随机模拟(或者统计模拟)方法。
给定统计样本集,如何估计产生这个样本集的随机变量概率密度函数,是我们比较熟悉的概率密度估计问题。
求解概率密度估计问题的常用方法是最大似然估计、最大后验估计等。但是,我们思考概率密度估计问题的逆问题:给定一个概率分布p(x)
,如何让计算机生成满足这个概率分布的样本。
这个问题就是统计模拟中研究的重要问题–采样(Sampling
)。本文将重点介绍其中两种重要的采样算法:MCMC(Markov Chain Monte Carlo)
算法和Gibbs Sampling
算法。
Sampling
一般而言均匀分布Uniform(0,1)
的样本是相对容易生成的。 通过线性同余发生器可以生成伪随机数,我们用确定性算法生成[0,1]
之间的伪随机数序列后,
这些序列的各种统计指标和均匀分布Uniform(0,1)
的理论计算结果非常接近。这样的伪随机序列就有比较好的统计性质,可以被当成真实的随机数使用。线性同余随机数生成器如下:
式中a
,c
,m
是数学推导出的合适的常数。这种算法产生的下一个随机数完全依赖当前的随机数,当随机数序列足够大的时候,随机数会出现重复子序列的情况。
当然,也有很多更加先进的随机数产生算法出现,比如numpy
用的是 Mersenne Twister
等。根据上面的算法现在我们有了均匀分布的随机数,但是如何产生满足其他分布下的随机数呢?
首先我们来看一个简单的例子,假设我们想对下面的二项分布进行采样:
我们如何采样得到
如果
这也很简单,让算机生成0到1之间的随机数
当然这种情况下也有容易解决的例子,比如随机变量都是独立的情形:
我们可以利用前面的算法对每个变量单独进行采样。但是问题并不总是这么简单,当
此时就需要使用一些更加复杂的随机模拟的方法来生成样本。而本节中将要重点介绍的 MCMC
算法和 Gibbs Sampling
算法就是最常用的两种,这两个方法在现代贝叶斯分析中被广泛使用。
要了解这两个算法,我们首先要对马尔科夫链的平稳分布的性质有基本的认识。
马尔科夫链及其平稳分布
马尔科夫链的数学定义如下:
也就是说前一个状态只与当前状态有关,而与其他状态无关,马尔科夫链体现的是状态空间的转换关系,下一个状态只决定于当前的状态。
下面举一个例子。
社会学家经常把人按其经济状况分成3类:下层(lower-class
)、中层(middle-class
)、上层(upper-class
),我们用 1, 2, 3 分别代表这三个阶层。
社会学家们发现决定一个人的收入阶层的最重要的因素就是其父母的收入阶层。如果一个人的收入属于下层类别,那么他的孩子属于下层收入的概率是 0.65, 属于中层收入的概率是 0.28,
属于上层收入的概率是 0.07。事实上,从父代到子代,收入阶层的变化的转移概率如下
使用矩阵的表示方式,转移概率矩阵记为
假设当前这一代人处在下层、中层、上层的人的比例是概率分布向量
他们的孙子代的分布比例将是
假设初始概率分布为
代码如下:
结果如下:
0.2000000 0.3000000 0.5000000
0.2350000 0.4370000 0.3280000
0.2576600 0.4766700 0.2656700
0.2708599 0.4871549 0.2419852
0.2781704 0.4893492 0.2324804
0.2821108 0.4894446 0.2284446
0.2842021 0.4891590 0.2266390
0.2853019 0.4889031 0.2257950
0.2858771 0.4887358 0.2253871
0.2861769 0.4886379 0.2251851
0.2863329 0.4885836 0.2250835
我们发现,最终会趋于稳定值。我们把初始分布改为
0.5000000 0.4000000 0.1000000
0.3970000 0.4440000 0.1590000
0.3437300 0.4658800 0.1903900
0.3161533 0.4769244 0.2069223
0.3018690 0.4825543 0.2155767
0.2944672 0.4854423 0.2200905
0.2906309 0.4869297 0.2224394
0.2886423 0.4876977 0.2236600
0.2876113 0.4880949 0.2242937
0.2870769 0.4883005 0.2246226
0.2867997 0.4884070 0.2247932
最终还是收敛于相同的值。即这个最终分布于初始状态无关。我们迭代转移矩阵
于是引出如下定理:
马尔科夫链定理: 如果一个非周期马尔科夫链具有转移概率矩阵
记
1
2
马尔科夫链定理非常重要,所有的MCMC
方法都是以这个定理作为理论基础的。定理内容有一些需要解释说明的地方:
该定理中马氏链的状态不要求有限,可以是有无穷多个的;
定理中的“非周期“这个概念不解释,因为我们遇到的绝大多数马氏链都是非周期的;
两个状态
i ,j 是连通并非指i 可以直接一步转移到j(Pij>0) ,而是指i 可以通过有限的n 步转移到达j(Pnij>0) 。
马氏链的任何两个状态是连通的含义是指存在一个n ,使得矩阵Pn 中任何一个元素的值大于零。我们用
Xi 表示在马氏链上跳转第i 步所处的状态,如果limn→∞Pnij=π(j) 存在,很容易证明定理第二个结论。
MCMC采样算法
对于给定的概率分布
如果我们能构造一个转移矩阵为
得到一个转移序列
这个想法在1953 年被Metropolis
想到的,首次提出了基于马氏链的蒙特卡罗方法,即Metropolis
算法,并在最早的计算机上编程实现。
Metropolis
算法是首个普适的采样方法,并启发了一系列MCMC
方法。Metropolis
的这篇论文被收录在《统计学中的重大突破》中,Metropolis算法也被遴选为二十世纪的十个最重要的算法之一。
由上一节定理我们看到了,马尔科夫链的收敛性质主要由转移矩阵
如何能做到这一点呢?我们主要使用如下的定理。
细致平稳条件 : 如果非周期马尔科夫链的转移矩阵
则
detailed balance condition
)。 其实这个定理是显而易见的,因为细致平稳条件的物理含义就是对于任何两个状态
从
假设我们已经有一个转移矩阵为
也就是细致平稳条件不成立,所以
取什么样的
于是上述式成立了。 在改造
我们以
以上的MCMC
采样算法已经能很漂亮的工作了,不过它有一个小的问题:马尔科夫链
假设
上式两边同时扩大5倍,等式变为
我们提高了接受率,而细致平稳条件并没有打破。这启发我们可以把细致平稳条件中
故可以
MCMC
采样算法中接受率的改造,我们就得到了最常见的Metropolis-Hastings
算法。
Gibbs Sampling
对于高维的情形,由于接受率
Metropolis-Hastings
算法的效率不够高。能否找到一个转移矩阵
我们先看看二维的情形,假设有一个概率分布
所以得到
即:
基于以上等式,我们发现,在
同样的,如果我们在
于是这个二维空间上的马氏链将收敛到平稳分布
Gibbs Sampling
算法,是Stuart Geman
和Donald Geman
这两兄弟于1984年提出来的,
之所以叫做Gibbs Sampling
是因为他们研究了Gibbs random field
, 这个算法在现代贝叶斯分析中占据重要位置。
如果当前状态为
其它无法沿着单根坐标轴进行的跳转,转移概率都设置为0。 于是Gibbs Smapling
算法可以描述为:
JAVA程序
对应的java程序为:
https://github.com/endymecy/MCMC-sampling
Gibbs采样程序:
package main.java.sample;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.Random;
/**
*
* Gibb's Sampling Steps:<br>
* 1. Set every variable to a random value.<br>
* 2. Choose a variable to update. <br>
* 3. Randomly Select (aka "Sample") a new value for the variable based on the
* current conditions. <br>
* 4. Repeat from Step 2.
*
*
*/
public class Gibbs {
private static ArrayList<String> sequences = new ArrayList<String>();
Hashtable<String, Integer> start;
private int motifLength;
/**
* Constructs and performs Gibb's Sampling in order to find repeated motifs.
*
* @param seq
* A String array of the sequences that will be used.
* @param motifLength
* An Integer that shows the length of the motif or pattern we
* are trying to find, this value is given.
*/
public Gibbs(String[] seqArray, int motifLength) {
sequences.addAll(Arrays.asList(seqArray));
this.motifLength = motifLength;
this.start = generateRandomValue();
sample();
System.out.println(start);
}
/**
* This method is repeated 2000 times.
*
* @param start
* A HashTable containing the sequence as a key, and the random
* integer to be used as the value.
*/
private void sample() {
for (int j = 0; j < 2000; j++) {
Random rand = new Random();
int chosenSeqIndex = rand.nextInt(sequences.size());
String chosenSequence = sequences.get(chosenSeqIndex);
ArrayList<Double> scores = new ArrayList<Double>();
// i = possibleStart
for (int i = 0; i < chosenSequence.length() - motifLength + 1; i++) {
String tempMotif = chosenSequence.substring(i, i + motifLength);
double p = calculateP(tempMotif, chosenSeqIndex);
double q = calculateQ(tempMotif, chosenSeqIndex, i);
scores.add(q / p);
}
double sum = 0;
for (double d : scores) {
sum += d;
}
for (int i = 0; i < scores.size(); i++) {
scores.set(0, scores.get(i) / sum);
}
double random = rand.nextDouble();
double dubsum = 0;
for (double d : scores) {
dubsum += d;
if (random == dubsum) {
start.put(chosenSequence, scores.indexOf(d));
}
}
}
}
/**
* Calculates the probability of a letter in this position.
*
* @param tempMotif
* The motif being used for this calculation.
* @param chosenSeqIndex
* The index of the sequence being used for this calculation,
* useful for skipping all of this sequences calculations and
* focusing on the other ones.
* @return A double of the probability of a letter in this position.
*/
private double calculateQ(String tempMotif, int chosenSeqIndex,
int possibleStart) {
double q = 1;
int start = possibleStart;
int end = possibleStart + tempMotif.length();
double denominator = sequences.size() - 1;
for (String s : sequences) {
double numerator = 0;
if (s.equals(sequences.get(chosenSeqIndex)))
continue;
if (end > s.length()) {
q *= 0.01;
continue;
}
String thisMotif = s.substring(start, end);
char[] letters = tempMotif.toCharArray();
char[] seqLetters = thisMotif.toCharArray();
for (int i = 0; i < tempMotif.length(); i++) {
if (letters[i] == seqLetters[i])
numerator++;
}
if (numerator == 0)
q *= 0.01;
else
q *= (numerator / denominator);
}
return q;
}
/**
* Calculates the probability of a letter randomly selected.
*
* To find this value, the method loops through each letter of the selected
* temporary motif, and loops through the other sequences. While looping
* through the other sequences, we find the amount of same letters in each
* other sequence, along with the total length of all other sequences. The
* value P is a product of every result, each result being the amount of
* letters of the same kind over the total amount of letters.
*
* @param tempMotif
* The motif being used for this calculation.
* @param chosenSeqIndex
* The index of the sequence being used for this calculation,
* useful for skipping all of this sequences calculations and
* focusing on the other ones.
* @return A double of the probability of a letter randomly selected.
*/
private double calculateP(String tempMotif, int chosenSeqIndex) {
double p = 1;
for (char c : tempMotif.toCharArray()) {
double sameLetters = 0;
double totalLength = 0;
for (String s : sequences) {
if (s.equals(sequences.get(chosenSeqIndex)))
continue;
char[] seqLetters = s.toCharArray();
for (char x : seqLetters)
if (c == x)
sameLetters++;
totalLength += s.length();
}
p *= (sameLetters / totalLength);
}
return p;
}
/**
* Calculates and stores every random value. Generates a random from 0 to a
* value of each individual sequences length subtracted by the motif length.
*
* @return A HashTable containing the sequence as a key, and the random
* integer to be used as the value.
*/
private Hashtable<String, Integer> generateRandomValue() {
Random rand = new Random();
Hashtable<String, Integer> randomValues = new Hashtable<String, Integer>();
for (String seq : sequences) {
int randomVal = rand.nextInt(seq.length() - motifLength);
randomValues.put(seq, randomVal);
}
return randomValues;
}
public static void main(String[] args) {
String[] data = { "ABCDAAAABDB", "AAAADCBBCA", "DDBCABAAAACBBD",
"AABAAAACCDD" };
int length = 4;
@SuppressWarnings("unused")
Gibbs gibbs = new Gibbs(data, length);
}
}
参考文献
【1】蒙特卡洛方法采样算法
【2】随机采样方法整理与讲解(MCMC、Gibbs Sampling等)
原文来自
endymecy所写程序中的理论部分。
https://github.com/endymecy/MCMC-sampling