当前位置:   article > 正文

自然语言处理系列之Viterbi算法_语音识别viterbi之后怎么处理

语音识别viterbi之后怎么处理

  前面已经介绍了隐马尔可夫模型,本篇博文主要是介绍用 viterbi 算法来解决 HMM 中的预测问题,也称为解码问题。
  维特比算法实际是用动态规划解隐马尔可夫模型预测问题,即用动态规划(dynamic programming)求概率最大路径(最优路径)。这时一条路径对应着一个状态序列。
  根据动态规划原理,最优路径具有这样的特性:如果最优路径在时刻t通过(it),那么这一路径从it到终点iT的部分路径,对于从itiT的所有可能的部分路径来说,必须是最优的。因为假如不是这样,那么从i1到终点iT就有另一条更好的部分路径存在,如果把它和i1到终点it的部分路径连接起来,就会形成一条比原来的路径更优的路径,这是矛盾的。依据这一原理,我们只需从时刻t=1开始,递推地计算在时刻t状态为i的各条部分路径的最大概率,直至得到时刻t=T状态为i的各条路径的最大概率。时刻t=T的最大概率即为最优路径的概率P,最优路径的终结点iT也同时得到。之后,为了找出最优路径的各个结点,从终结点iT开始,由后向前逐步求得结点iT1,...,i1得到最优路径这就是维特比算法。

  • viterbi 算法
    输入:模型λ=(A,B,π)和观测O=(o1,o2,...,oT);
    输出:最优路径(i1,...,iT1,iT).
    (1) 初始化

    δ1(i)=πibi(oi),i=1,2,...,N

    ψ1(i)=0,i=1,2,...,N

    (2) 递推.对t=2,3,...,T
    δt(i)=max[δt1(j)aji]bi(ot),i=1,2,..,N;1jN

    ψt(i)=argmax[δt1(j)aji],i=1,2,...,N;1jN

    (3) 终止
    P=maxδT(i),1jN

    iT=argmax[δT(i)],1jN

    (4)最优路径回溯. 对t=T1,T2,...,1
    it=ψt+1(it+1)

  • viterbi算法实现

package com.feng.nlp.algorithm;

import java.util.*;

/**
 * Created by lionel on 17/4/11.
 */
public class Viterbi {
    public static List<String> compute(String[] observe, String[] status, double[] start_p, double[][] transfer_p, double[][] observe_p) {
        double[][] theta = new double[observe.length][status.length];
        int[][] delta = new int[observe.length][status.length];
        transfermation(start_p, transfer_p, observe_p);
        for (int j = 0; j < status.length; j++) {
            theta[0][j] = start_p[j] + observe_p[j][0];
            delta[0][j] = 0;
        }
        Map<String, Integer> map = new HashMap<String, Integer>();
        int index = 0;
        for (String ele : observe) {
            if (map.containsKey(ele)) {
                continue;
            }
            map.put(ele, index);
            index++;
        }

        for (int i = 1; i < observe.length; i++) {
            for (int j = 0; j < status.length; j++) {
                int direction = 0;
                double prob = Double.MAX_VALUE;
                for (int k = 0; k < status.length; k++) {
                    double tmpProb = theta[i - 1][k] + transfer_p[k][j] + observe_p[j][map.get(observe[i])];
                    if (tmpProb < prob) {
                        prob = tmpProb;
                        direction = k;
                        theta[i][j] = prob;
                    }
                }
                delta[i][j] = direction;
            }
        }
//        for (int i = 0; i < theta.length; i++) {
//            for (int j = 0; j < theta[i].length; j++) {
//                System.out.print(theta[i][j] + " ");
//            }
//            System.out.println();
//        }
        double prob = Double.MAX_VALUE;
        int pos = 0;
        for (int j = 0; j < status.length; j++) {
            if (theta[observe.length - 1][j] < prob) {
                prob = theta[observe.length - 1][j];
                pos = j;
            }
        }
        List<String> res = new ArrayList<String>();
        res.add(status[pos]);
        //回溯路径
        for (int i = observe.length - 1; i > 0; i--) {
            res.add(status[delta[i][pos]]);
            pos = delta[i][pos];
        }

        Collections.reverse(res);
        return res;
    }

    public static void transfermation(double[] start_p, double[][] transfer_p, double[][] observe_p) {
        for (int i = 0; i < start_p.length; ++i) {
            start_p[i] = -Math.log(start_p[i]);
        }
        for (int i = 0; i < transfer_p.length; ++i) {
            for (int j = 0; j < transfer_p[i].length; ++j) {
                transfer_p[i][j] = -Math.log(transfer_p[i][j]);
            }
        }
        for (int i = 0; i < observe_p.length; ++i) {
            for (int j = 0; j < observe_p[i].length; ++j) {
                observe_p[i][j] = -Math.log(observe_p[i][j]);
            }
        }
    }


    public static void main(String[] args) {
        String[] observe = {"红", "白", "红"};
        String[] status = {"1", "2", "3"};
        double[] start_p = new double[]{0.2, 0.4, 0.4};
        double[][] transfer_p = new double[][]{
                {0.5, 0.2, 0.3},
                {0.3, 0.5, 0.2},
                {0.2, 0.3, 0.5}
        };
        double[][] observe_p = new double[][]{
                {0.5, 0.5},
                {0.4, 0.6},
                {0.7, 0.3}
        };
        List<String> result = compute(observe, status, start_p, transfer_p, observe_p);
        System.out.println(result);//[3, 3, 3]

    }
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104

  测试用例来源于李航老师的《统计机器学习》的例子。

  • 参考资料:《统计机器学习》,李航
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/980453
推荐阅读
相关标签
  

闽ICP备14008679号