当前位置:   article > 正文

深度学习基础技术分析6:LSTM(含代码分析)_lstm结构代码

lstm结构代码

1. 模型图示

LSTM 模型如 图1 所示。横向穿过 cell 上部的线分别称作 c \mathbf{c} c 总线,下部的线称为 h \mathbf{h} h 总线,这意味着 c t − 1 \mathbf{c}_{t - 1} ct1 h t − 1 \mathbf{h}_{t - 1} ht1 会对 t t t 时刻的计算产生影响 。其中:

  1. x t x_t xt 与下

在这里插入图片描述图1. LSTM 模型

2. 相关技术

LSTM 从名称来看,是用于处理长短时序。

3. 代码分析

程序代码见: https://github.com/garstka/char-rnn-java
为了学习它, 我又来逐个方法来分析.

// 前向传播核心代码
// acts 根据字符串存取实型二维数组
public void active(int t, Map<String, DoubleMatrix> acts) {
    // 获取 t 时刻输入
    DoubleMatrix x = acts.get("x" + t);
    // 上一时刻的 h 和 c
    DoubleMatrix preH = null, preC = null;
    if (t == 0) {
        preH = new DoubleMatrix(1, getOutSize());
        preC = preH.dup();
    } else {
        preH = acts.get("h" + (t - 1));
        preC = acts.get("c" + (t - 1));
    }
    
    DoubleMatrix i = Activer.logistic(x.mmul(Wxi).add(preH.mmul(Whi)).add(preC.mmul(Wci)).add(bi));
    DoubleMatrix f = Activer.logistic(x.mmul(Wxf).add(preH.mmul(Whf)).add(preC.mmul(Wcf)).add(bf));
    DoubleMatrix gc = Activer.tanh(x.mmul(Wxc).add(preH.mmul(Whc)).add(bc));
    DoubleMatrix c = f.mul(preC).add(i.mul(gc));
    DoubleMatrix o = Activer.logistic(x.mmul(Wxo).add(preH.mmul(Who)).add(c.mmul(Wco)).add(bo));
    DoubleMatrix gh = Activer.tanh(c);
    DoubleMatrix h = o.mul(gh);
    
    // 存储各个二维矩阵
    acts.put("i" + t, i);
    acts.put("f" + t, f);
    acts.put("gc" + t, gc);
    acts.put("c" + t, c);
    acts.put("o" + t, o);
    acts.put("gh" + t, gh);
    acts.put("h" + t, h);
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

在我运行的程序中, x t x_t xt 为 one-hot 编码的 1 × 62 1 \times 62 1×62 向量, i t i_t it h t h_t ht 均为 1 × 100 1 \times 100 1×100 向量.

代码所表示的信息比 图1 更丰富。矩阵变量之间要运算,很多时候要乘以权重矩阵。为使得结构更清晰,图1 牺牲了表达的准确性。以下将向量的计算翻译成数学表达式,这些向量都会被存储在模型中。

  1. 向量 i t \mathbf{i}_t it 表示 t t t 时刻输入:
    i t = σ ( W x i ⋅ x t + W h i ⋅ h t − 1 + W c i ⋅ c t − 1 + b i ) (1) \mathbf{i}_t = \sigma(\mathbf{W}^{xi} \cdot \mathbf{x}_t + \mathbf{W}^{hi} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{ci} \cdot \mathbf{c}_{t - 1} + bi) \tag{1} it=σ(Wxixt+Whiht1+Wcict1+bi)(1)
  2. 向量 f t \mathbf{f}_t ft 表示遗忘:
    i t = σ ( W x f ⋅ x t + W h f ⋅ h t − 1 + W c f ⋅ c t − 1 + b f ) (2) \mathbf{i}_t = \sigma(\mathbf{W}^{xf} \cdot \mathbf{x}_t + \mathbf{W}^{hf} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{cf} \cdot \mathbf{c}_{t - 1} + bf) \tag{2} it=σ(Wxfxt+Whfht1+Wcfct1+bf)(2)
  3. 向量 g c t \mathbf{gc}_t gct 表示
    g c t = t a n h ( W x c ⋅ x t + W h c ⋅ h t − 1 + b c ) (3) \mathbf{gc}_t = tanh(\mathbf{W}^{xc} \cdot \mathbf{x}_t + \mathbf{W}^{hc} \cdot \mathbf{h}_{t - 1} + bc) \tag{3} gct=tanh(Wxcxt+Whcht1+bc)(3)
  4. 向量 c t \mathbf{c}_t ct 表示
    c t = tanh ⁡ ( f ⊙ c t − 1 + i t ⊙ g c t ) (4) \mathbf{c}_t = \tanh(\mathbf{f} \odot \mathbf{c}_{t - 1} + \mathbf{i}_{t} \odot \mathbf{gc}_t) \tag{4} ct=tanh(fct1+itgct)(4)
  5. 向量 o t \mathbf{o}_t ot 表示
    o t = σ ( W x o ⋅ x t + W h o ⋅ h t − 1 + W c o ⋅ c t + b o ) (5) \mathbf{o}_t = \sigma(\mathbf{W}^{xo} \cdot \mathbf{x}_t + \mathbf{W}^{ho} \cdot \mathbf{h}_{t - 1} + \mathbf{W}^{co} \cdot \mathbf{c}_t + bo) \tag{5} ot=σ(Wxoxt+Whoht1+Wcoct+bo)(5)
  6. 向量 g h t \mathbf{gh}_t ght 表示
    g h t = tanh ⁡ ( c t ) (6) \mathbf{gh}_t = \tanh(\mathbf{c}_t) \tag{6} ght=tanh(ct)(6)
  7. 向量 h t \mathbf{h}_t ht 表示本时刻输出.
    h t = o t ⊙ g h t (7) \mathbf{h}_t = \mathbf{o}_t \odot \mathbf{gh}_t \tag{7} ht=otght(7)
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/335114
推荐阅读
相关标签
  

闽ICP备14008679号