自然语言处理系列之Viterbi算法

  前面已经介绍了隐马尔可夫模型,本篇博文主要是介绍用 viterbi 算法来解决 HMM 中的预测问题,也称为解码问题。
  维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划(dynamic programming)求概率最大路径(最优路径)。这时一条路径对应着一个状态序列。
  根据动态规划原理,最优路径具有这样的特性:如果最优路径在时刻t通过 (it) ,那么这一路径从 it 到终点 iT 的部分路径,对于从 it iT 的所有可能的部分路径来说,必须是最优的。因为假如不是这样,那么从 i1 到终点 iT 就有另一条更好的部分路径存在,如果把它和 i1 到终点 it 的部分路径连接起来,就会形成一条比原来的路径更优的路径,这是矛盾的。依据这一原理,我们只需从时刻t=1开始,递推地计算在时刻t状态为i的各条部分路径的最大概率,直至得到时刻 t=T 状态为i的各条路径的最大概率。时刻 t=T 的最大概率即为最优路径的概率 P ,最优路径的终结点 iT 也同时得到。之后,为了找出最优路径的各个结点,从终结点 iT 开始,由后向前逐步求得结点 iT1,...,i1 得到最优路径这就是维特比算法。

  • viterbi 算法
    输入:模型 λ=(A,B,π) 和观测 O=(o1,o2,...,oT) ;
    输出:最优路径 (i1,...,iT1,iT) .
    (1) 初始化

    δ1(i)=πibi(oi),i=1,2,...,N

    ψ1(i)=0,i=1,2,...,N

    (2) 递推.对 t=2,3,...,T
    δt(i)=max[δt1(j)aji]bi(ot),i=1,2,..,N;1jN

    ψt(i)=argmax[δt1(j)aji],i=1,2,...,N;1jN

    (3) 终止
    P=maxδT(i),1jN

    iT=argmax[δT(i)],1jN

    (4)最优路径回溯. 对 t=T1,T2,...,1
    it=ψt+1(it+1)

  • viterbi算法实现

package com.feng.nlp.algorithm;

import java.util.*;

/**
 * Created by lionel on 17/4/11.
 */
public class Viterbi {
    public static List<String> compute(String[] observe, String[] status, double[] start_p, double[][] transfer_p, double[][] observe_p) {
        double[][] theta = new double[observe.length][status.length];
        int[][] delta = new int[observe.length][status.length];
        transfermation(start_p, transfer_p, observe_p);
        for (int j = 0; j < status.length; j++) {
            theta[0][j] = start_p[j] + observe_p[j][0];
            delta[0][j] = 0;
        }
        Map<String, Integer> map = new HashMap<String, Integer>();
        int index = 0;
        for (String ele : observe) {
            if (map.containsKey(ele)) {
                continue;
            }
            map.put(ele, index);
            index++;
        }

        for (int i = 1; i < observe.length; i++) {
            for (int j = 0; j < status.length; j++) {
                int direction = 0;
                double prob = Double.MAX_VALUE;
                for (int k = 0; k < status.length; k++) {
                    double tmpProb = theta[i - 1][k] + transfer_p[k][j] + observe_p[j][map.get(observe[i])];
                    if (tmpProb < prob) {
                        prob = tmpProb;
                        direction = k;
                        theta[i][j] = prob;
                    }
                }
                delta[i][j] = direction;
            }
        }
//        for (int i = 0; i < theta.length; i++) {
//            for (int j = 0; j < theta[i].length; j++) {
//                System.out.print(theta[i][j] + " ");
//            }
//            System.out.println();
//        }
        double prob = Double.MAX_VALUE;
        int pos = 0;
        for (int j = 0; j < status.length; j++) {
            if (theta[observe.length - 1][j] < prob) {
                prob = theta[observe.length - 1][j];
                pos = j;
            }
        }
        List<String> res = new ArrayList<String>();
        res.add(status[pos]);
        //回溯路径
        for (int i = observe.length - 1; i > 0; i--) {
            res.add(status[delta[i][pos]]);
            pos = delta[i][pos];
        }

        Collections.reverse(res);
        return res;
    }

    public static void transfermation(double[] start_p, double[][] transfer_p, double[][] observe_p) {
        for (int i = 0; i < start_p.length; ++i) {
            start_p[i] = -Math.log(start_p[i]);
        }
        for (int i = 0; i < transfer_p.length; ++i) {
            for (int j = 0; j < transfer_p[i].length; ++j) {
                transfer_p[i][j] = -Math.log(transfer_p[i][j]);
            }
        }
        for (int i = 0; i < observe_p.length; ++i) {
            for (int j = 0; j < observe_p[i].length; ++j) {
                observe_p[i][j] = -Math.log(observe_p[i][j]);
            }
        }
    }


    public static void main(String[] args) {
        String[] observe = {"红", "白", "红"};
        String[] status = {"1", "2", "3"};
        double[] start_p = new double[]{0.2, 0.4, 0.4};
        double[][] transfer_p = new double[][]{
                {0.5, 0.2, 0.3},
                {0.3, 0.5, 0.2},
                {0.2, 0.3, 0.5}
        };
        double[][] observe_p = new double[][]{
                {0.5, 0.5},
                {0.4, 0.6},
                {0.7, 0.3}
        };
        List<String> result = compute(observe, status, start_p, transfer_p, observe_p);
        System.out.println(result);//[3, 3, 3]

    }
}

  测试用例来源于李航老师的《统计机器学习》的例子。

  • 参考资料:《统计机器学习》,李航

猜你喜欢

转载自blog.csdn.net/lionel_fengj/article/details/70196670