当前位置:   article > 正文

modeling_bert.py

modeling_bert.py

BertEmbeddings类

在这里插入图片描述

pytorch中nn.Embedding原理及使用
https://www.jianshu.com/p/63e7acc5e890

token embeddings

在这里插入图片描述

def __init__(self, config):
    super().__init__()
# vocab_size默认为30522;hidden_size默认为768
# 【word_embeddings】词典大小:vocab_size;向量维度:hidden_size
# token embedding层是要将各个词转换成固定维度的向量。在BERT中,每个词会被转换成768维的向量表示。
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if input_ids is not None:
        input_shape = input_ids.size()
    else:
        input_shape = inputs_embeds.size()[:-1]
	    
	if inputs_embeds is None:
        inputs_embeds = self.word_embeddings(input_ids)     # wordpiece分词后的token编码
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

segment embeddings

在这里插入图片描述
上图的实现和tokenization_bert.py中的【build_inputs_with_special_tokens函数】【get_special_tokens_mask函数】【create_token_type_ids_from_sequences函数】有关。

# 【token_type_embeddings】词典大小:type_vocab_size;向量维度:hidden_size
# NSP操作:判断属于句子A还是句子B
def __init__(self, config):
    super().__init__()
    self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
    if input_ids is not None:
        input_shape = input_ids.size()
    else:
        input_shape = inputs_embeds.size()[:-1]
    
    if token_type_ids is None:	# 如果这个参数是None,则不做NSP下个句子预测处理,句子类型全用0表示
        token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
    token_type_embeddings = self.token_type_embeddings(token_type_ids)      # 句子对的分割编码
	
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

position embeddings

BERT能够处理最长512个token的输入序列。 论文作者通过让BERT在各个位置上学习一个向量表示来讲序列顺序的信息编码进来。
这意味着Position Embeddings layer 实际上就是一个大小为 (512, 768)
的lookup表,表的第一行是代表第一个序列的第一个位置,第二行代表序列的第二个位置,以此类推。 因此,如果有这样两个句子“Hello
world” 和“Hi there”, “Hello” 和“Hi”会由完全相同的position
embeddings,因为他们都是句子的第一个词。同理,“world” 和“there”也会有相同的position embedding。
引用自 https://www.cnblogs.com/d0main/p/10447853.html

# 【position_embeddings】词典大小:max_position_embeddings;向量维度:hidden_size
# position_embeddings是一个大小为(512, 768)的lookup表,表的第一行是代表第一个序列的第一个位置,以此类推
# 两个句子的同一个位置的位置编码相同
def __init__(self, config):
    super().__init__()
    self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
     if input_ids is not None:
        input_shape = input_ids.size()
     else:
        input_shape = inputs_embeds.size()[:-1]

     seq_length = input_shape[1]

     if position_ids is None:
        position_ids = self.position_ids[:, :seq_length]

	 position_embeddings = self.position_embeddings(position_ids)    # 位置编码

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

BertSelfAttention类

hasattr() 函数用于判断对象是否包含对应的属性。
hasattr(object, name)
参数:object – 对象;name – 字符串,属性名。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/348648
推荐阅读
相关标签
  

闽ICP备14008679号