当前位置:   article > 正文

pytorch nn.Embedding的用法和理解_pytorch里nn.embedding的vocab是多少

pytorch里nn.embedding的vocab是多少

官方文档的示例:

  1. >>> # an Embedding module containing 10 tensors of size 3
  2. >>> embedding = nn.Embedding(10, 3)
  3. >>> # a batch of 2 samples of 4 indices each
  4. >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
  5. >>> embedding(input)
  6. tensor([[[-0.0251, -1.6902, 0.7172],
  7. [-0.6431, 0.0748, 0.6969],
  8. [ 1.4970, 1.3448, -0.9685],
  9. [-0.3677, -2.7265, -0.1685]],
  10. [[ 1.4970, 1.3448, -0.9685],
  11. [ 0.4362, -0.4004, 0.9400],
  12. [-0.6431, 0.0748, 0.6969],
  13. [ 0.9124, -2.3616, 1.1151]]])

我不太懂的是定义完nn.Embedding(num_embeddings-词典长度,embedding_dim-向量维度)之后,为什么就可以直接使用embedding(input)进行输入。
我们来仔细看看:

>>> embedding = nn.Embedding(10, 3)      

构造一个(假装)vocab size=10,每个vocab用3-d向量表示的table

  1. >>> embedding.weight
  2. Parameter containing:
  3. tensor([[ 1.2402, -1.0914, -0.5382],
  4. [-1.1031, -1.2430, -0.2571],
  5. [ 1.6682, -0.8926, 1.4263],
  6. [ 0.8971, 1.4592, 0.6712],
  7. [-1.1625, -0.1598, 0.4034],
  8. [-0.2902, -0.0323, -2.2259],
  9. [ 0.8332, -0.2452, -1.1508],
  10. [ 0.3786, 1.7752, -0.0591],
  11. [-1.8527, -2.5141, -0.4990],
  12. [-0.6188, 0.5902, -0.0860]], requires_grad=True)

可以看做每行是一个词汇的向量表示!

  1. >>> embedding.weight.size
  2. torch.Size([10, 3])

和nn.Embedding处的定义一致

  1. >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]])
  2. >>> input
  3. tensor([[1, 2, 4, 5],
  4. [4, 3, 2, 9]])

牢记:input是indices

  1. >>> input.shape
  2. torch.Size([2, 4])

 Input size表示这批有2个句子,每个句子由4个单词构成

  1. >>> a = embedding(input)
  2. >>> a
  3. tensor([[[-1.1031, -1.2430, -0.2571],
  4. [ 1.6682, -0.8926, 1.4263],
  5. [-1.1625, -0.1598, 0.4034],
  6. [-0.2902, -0.0323, -2.2259]],
  7. [[-1.1625, -0.1598, 0.4034],
  8. [ 0.8971, 1.4592, 0.6712],
  9. [ 1.6682, -0.8926, 1.4263],
  10. [-0.6188, 0.5902, -0.0860]]], grad_fn=<EmbeddingBackward>)

a=embedding(input)是去embedding.weight中取对应index的词向量!
看a的第一行,input处index=1,对应取出weight中index=1的那一行。其实就是按index取词向量!

  1. >>> a.size()
  2. torch.Size([2, 4, 3])

取出来之后编程2*4*3的张量。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/363864
推荐阅读
相关标签
  

闽ICP备14008679号