当前位置:   article > 正文

学习笔记——torch.nn.RNN()

torch.nn.rnn

1. 调用方式

官方文档RNN — PyTorch 1.13 documentation

用于实现RNN层,并可通过传入参数实现多层堆叠(深度循环神经网络)、双向传播(双向循环神经网络):

示例(包含输入输出格式):

  1. import torch
  2. import torch.nn as nn
  3. myrnn = nn.RNN(4,3,2,batch_first=True) #input_size,hidden_size,num_layers
  4. print("myrnn:", myrnn)
  5. input = torch.randn(2,3,4) #输入数据集格式(batch_size, sequence_length, input_size(已限定为4))
  6. print("input:", input)
  7. output, h_n = myrnn(input) #output为每个时刻的隐藏状态,格式为(batch_size,sequence_length,hidden_size);h_n为最后时刻的隐藏状态,格式为(num_layers,batch_size,hidden_size)
  8. print("output:", output)
  9. print("h_n:", h_n)

输出:

* hidden_size类似于全连接网络的结点个数

2. 关于batch_first

输入数据集格式(batch_first默认False):

其中,N=batch_size批量大小,L=sequence_length序列长度,H=input_size输入尺寸。

默认顺序为(sequence_length,batch_size,input_size),与通常batch_size在第一维度有所不同,原因参考MultiHeadAttension源码解析——batch_first参数含义_coder1479的博客-CSDN博客读PyTorch源码学习RNN(1) - 知乎可知:

“由于RNN是序列模型,只有 t1 时刻计算完成,才能进入 t2 时刻,而"batch"就体现在每个时刻 ti 的计算过程中,图中 t1 时刻将["A", "D"]作为当前时刻的batch数据,t2 时刻将["B", "E"]作为当前时刻的batch数据,可想而知,"A"与"D"在内存中相邻比"A"与"B"相邻更合理,这样取数据时才更高效。” 

实际使用中可将batch_first设置为True,按照(batch_size,sequence_length,input_size)顺序传入参数,此时函数会自动将其转换成默认的顺序(sequence_length,batch_size,input_size),并且在输出结果的时候,再转换回来。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/码创造者/article/detail/897203
推荐阅读
相关标签
  

闽ICP备14008679号