当前位置:   article > 正文

BERT Pytorch版本 源码解析(二)_input_ids.size()

input_ids.size()

BERT Pytorch版本 源码解析(二)

四、BertEmbedding 类解析

BertEmbedding部分是组成 BertModel 的第一部分,今天就来讲讲 BertEmbedding 的内部实现细节。

4.1、Embedding 的组成以及设置

  1. def __init__(self, config):
  2. super(BertEmbeddings, self).__init__()
  3. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
  4. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  5. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  6. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  7. # any TensorFlow checkpoint file
  8. self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
  9. self.dropout = nn.Dropout(config.hidden_dropout_prob)

上面的代码是 BertEmbedding 类的初始化函数,在这块很明显 BertEmbedding 并似乎并没有很特别的地方。总的是设置了三种类型的 embedding,分别是word_embedding,position_embedding,token_type_embedding三种组成。首先,这三种embedding都是用pytorch自带的nn.Embedding 随机生成的,而且它们的向量长度都是 config.hidden_size。之后是一个常见的LayerNorm 以及 Dropout层,这部分就不解释了。

4.2、具体实现

  1. def forward(self, input_ids, token_type_ids=None):
  2. seq_length = input_ids.size(1)
  3. position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
  4. position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
  5. if token_type_ids is None:
  6. token_type_ids = torch.zeros_like(input_ids)
  7. words_embeddings = self.word_embeddings(input_ids)
  8. position_embeddings = self.position_embeddings(position_ids)
  9. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  10. embeddings = words_embeddings + position_embeddings + token_type_embeddings
  11. embeddings = self.LayerNorm(embeddings)
  12. embeddings = self.dropout(embeddings)
  13. return embeddings

首先输入是input_ids或者token_type_ids,input_ids是一个[Batch_size, Seq_length]维度的向量,每一个元素表示对应词表中的index,token_type_ids是对于一个输入存在两个句子的情况,利用 0 和 1 来区分第几个句子的,所以这个部分其实对于大部分任务来说是可以省略的。

然后是关于position_ids的生成,它是自动生成的一个向量,torch.arange(seq_length)是自动生成一个从0开始到seq_length - 1的长度为seq_length的向量。

如果 token_type_ids 是None的情况下则自动生成一个全为0的向量,即所有的输入都是单句的输入。

之后就是利用nn.Embedding来生成三个[Batch_size, Seq_length, Hidden_size]的向量,然后将三个向量进行叠加操作之后进行LayerNorm以及Dropout操作,这就是BertEmbedding的工作原理。

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/312422
推荐阅读
相关标签
  

闽ICP备14008679号