// Java implementation of the Viterbi algorithm (original title: "java实现Viterbi算法").

package com.bj58.dia.rec.gul.wpai.dlpredictonline.impl;

public class Viterbi {
    /** Hidden-state indices used by the demo in {@link #main}. */
    private static final int[] status = { 0, 1, 2 };

    /** Demo observation sequence; symbols are 1-based (symbol k selects column k - 1 of {@link #B}). */
    private static final int[] observations = { 1, 6, 3, 5, 2, 7, 3, 5, 2, 4 };

    /**
     * Demo transition matrix: {@code transititon_probability[i][j]} is the probability of moving from
     * hidden state i to hidden state j. (The misspelled name is kept because the field is
     * package-visible and may be referenced elsewhere.)
     */
    static double[][] transititon_probability = new double[][] { { 1.0 / 3, 1.0 / 3, 1.0 / 3 },
            { 1.0 / 3, 1.0 / 3, 1.0 / 3 }, { 1.0 / 3, 1.0 / 3, 1.0 / 3 } };

    /** Demo initial distribution over the hidden states. */
    private static final double[] pai = new double[] { 1.0 / 3, 1.0 / 3, 1.0 / 3 };

    /**
     * Demo emission matrix: {@code B[i][k]} is the probability that hidden state i emits symbol k + 1.
     * The rows are uniform over 4, 6 and 8 symbols respectively (the classic three-dice example).
     */
    public static double[][] B = { { 1.0 / 4, 1.0 / 4, 1.0 / 4, 1.0 / 4, 0, 0, 0, 0, },
            { 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 1.0 / 6, 0, 0 },
            { 1.0 / 8, 1.0 / 8, 1.0 / 8, 1.0 / 8, 1.0 / 8, 1.0 / 8, 1.0 / 8, 1.0 / 8 } };

    /**
     * Decodes the most likely hidden-state sequence of an HMM for the given observations (Viterbi).
     *
     * <p>Probabilities are multiplied directly (not in log space), so very long sequences may
     * underflow to 0. Once that happens, ties are broken toward the lowest state index.
     *
     * @param obs
     *            observation sequence; symbols are 1-based, i.e. {@code obs[t] - 1} indexes the
     *            columns of {@code emit_p}
     * @param states
     *            hidden states; only its length (the number of states N) is used
     * @param start_p
     *            initial probability of each hidden state (length N)
     * @param trans_p
     *            transition probabilities; {@code trans_p[i][j]} = P(state j at t+1 | state i at t)
     * @param emit_p
     *            emission probabilities; {@code emit_p[i][k]} = P(symbol k + 1 | state i)
     * @return the most likely hidden sequence as indices into {@code states}, one per observation;
     *         an empty array if {@code obs} is empty
     * @throws IllegalArgumentException
     *             if there are no states or the model arrays do not have one row per state
     */
    public static int[] compute(int[] obs, int[] states, double[] start_p, double[][] trans_p, double[][] emit_p) {
        final int n = states.length;
        final int length = obs.length;
        if (length == 0) {
            return new int[0];
        }
        if (n == 0 || start_p.length != n || trans_p.length != n || emit_p.length != n) {
            throw new IllegalArgumentException(
                    "start_p, trans_p and emit_p must each have one entry/row per state; states=" + n);
        }

        // delta[i] = probability of the best path that ends in state i at the current time step.
        double[] delta = new double[n];
        // backPointer[t][j] = predecessor state (at t-1) on the best path ending in state j at t.
        int[][] backPointer = new int[length][n];

        for (int i = 0; i < n; i++) {
            delta[i] = start_p[i] * emit_p[i][obs[0] - 1];
        }

        double[] next = new double[n];
        for (int t = 1; t < length; t++) {
            int symbol = obs[t] - 1;
            for (int j = 0; j < n; j++) {
                // Strict '>' keeps the lowest index on ties; probabilities are >= 0, so i = 0 always wins first.
                double maxProb = -1;
                int argMax = 0;
                for (int i = 0; i < n; i++) {
                    double prob = delta[i] * trans_p[i][j];
                    if (prob > maxProb) {
                        maxProb = prob;
                        argMax = i;
                    }
                }
                next[j] = maxProb * emit_p[j][symbol];
                backPointer[t][j] = argMax;
            }
            // Reuse the two buffers instead of allocating a new one each step.
            double[] swap = delta;
            delta = next;
            next = swap;
        }

        // Choose the most probable final state (lowest index on ties).
        int state = 0;
        double best = -1;
        for (int i = 0; i < n; i++) {
            if (delta[i] > best) {
                best = delta[i];
                state = i;
            }
        }

        // Backtrack through the predecessor table to recover the full path.
        int[] path = new int[length];
        path[length - 1] = state;
        for (int t = length - 1; t > 0; t--) {
            state = backPointer[t][state];
            path[t - 1] = state;
        }
        return path;
    }

    /** Runs the three-dice demo and prints the decoded hidden states as 1-based numbers. */
    public static void main(String[] args) {
        int[] result = Viterbi.compute(observations, status, pai, transititon_probability, B);
        for (int r : result) {
            System.out.print((r + 1) + " ");
        }
        System.out.println();
    }
}

// Original blog page footer: "猜你喜欢" ("You may also like").
//
// Reprinted from blog.csdn.net/wshzd/article/details/89338437