隐马尔可夫模型之:前向算法

隐马尔可夫模型(hidden markov model 简称hmm)广泛应用于语音识别,机器翻译等领域。

隐马尔可夫模型的具体定义,请参考著名论文《A tutorial on Hidden Markov Models and selected applications in speech recognition》,在阅读以下内容之前,建议读者阅读这篇论文的第I II III 节,理论性的东西在此不做赘述。

hmm通常解决以下三类问题:

1.给定一个hmm和观察序列,判断生成这个观察序列的可能性;
2.给定一个hmm和观察序列,给出最可能生成这个观察序列的隐藏序列;
3.给定一个观察序列,训练一个hmm。

第1个问题,通常称为评估问题,可以用前向算法(forward algorithm)来解决,使用了动态规划技术,将该问题的时间复杂度降为O(N*N*T),其中N为隐藏状态的个数,T为给定的观察序列的长度,下面给出java代码:

package hmm;

import java.util.HashMap;
import java.util.Map;

/**
 * 隐马尔可夫模型
 * @author xuguanglv
 *
 */
public class Hmm {
	//初始概率向量
	private static double[] pai = {0.63, 0.17, 0.20};
	
	//状态转移矩阵
	private static double[][] A = {{0.500, 0.375, 0.125},
							        {0.250, 0.125, 0.625},
							        {0.250, 0.375, 0.375}};
	
	//混淆矩阵
	private static double[][] B = {{0.60, 0.20, 0.15, 0.05},
							        {0.25, 0.25, 0.25, 0.25},
							        {0.05, 0.10, 0.35, 0.50}};
	
	//隐藏状态索引
	private static Map<String, Integer> hiddenStateIndex = new HashMap<String, Integer>();
	static{
		hiddenStateIndex.put("S(0)", 0);
		hiddenStateIndex.put("S(1)", 1);
		hiddenStateIndex.put("S(2)", 2);
	}
	
	//观察状态索引
	private static Map<String, Integer> observableStateIndex = new HashMap<String, Integer>();
	static{
		observableStateIndex.put("O(0)", 0);
		observableStateIndex.put("O(1)", 1);
		observableStateIndex.put("O(2)", 2);
		observableStateIndex.put("O(3)", 3);
	}
	
	//前向算法 根据观察序列和已知的隐马尔可夫模型 返回这个模型生成这个观察序列的概率
	//alpha[t][j]表示t时刻由隐藏状态S(j)生成观察状态O(t)的概率
	public static double forward(String[] observedSequence){
		double[][] alpha = new double[observedSequence.length][A.length];
		
		//利用动态规划计算出alpha数组
		//初始化
		for(int i = 0; i <= A.length - 1; i++){
			int index = observableStateIndex.get(observedSequence[0]);
			alpha[0][i] = pai[i] * B[i][index];
		}
		for(int t = 1; t <= observedSequence.length - 1; t++){
			for(int j = 0; j <= A.length - 1; j++){
				double sum = 0;
				for(int i = 0; i <= A.length - 1; i++){
					sum += (alpha[t - 1][i] * A[i][j]);
				}
				int index = observableStateIndex.get(observedSequence[t]);
				alpha[t][j] = sum * B[j][index];
			}
		}
		double prob = 0;
		for(int i = 0; i <= A.length - 1; i++){
			prob += alpha[observedSequence.length - 1][i];
		}
		return prob;
	}
	
	public static void main(String[] args){
		String[] observedSequence = {"O(0)", "O(2)", "O(3)"};
		System.out.println(forward(observedSequence));
	}
}

猜你喜欢

转载自xglv2013.iteye.com/blog/2306970