def forward(self, x, mask, query_embed, pos_embed):
    """Forward function for `Transformer`.

    Args:
        x (Tensor): Input query with shape [bs, c, h, w] where
            c = embed_dims.
        mask (Tensor): The key_padding_mask used for encoder and decoder,
            with shape [bs, h, w].
        query_embed (Tensor): The query embedding for decoder, with shape
            [num_query, c].
        pos_embed (Tensor): The positional encoding for encoder and
            decoder, with the same shape as `x`.

    Returns:
        tuple[Tensor]: results of decoder containing the following tensor.

            - out_dec: Output from decoder. If return_intermediate_dec
              is True output has shape [num_dec_layers, bs, num_query,
              embed_dims], else has shape [1, bs, num_query, embed_dims].
            - memory: Output results from encoder, with shape
              [bs, embed_dims, h, w].
    """
    bs, c, h, w = x.shape
    # use `view` instead of `flatten` for dynamically exporting to ONNX
    x = x.view(bs, c, -1).permute(2, 0, 1)  # [bs, c, h, w] -> [h*w, bs, c]
    pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)
    query_embed = query_embed.unsqueeze(1).repeat(
        1, bs, 1)  # [num_query, dim] -> [num_query, bs, dim]
    mask = mask.view(bs, -1)  # [bs, h, w] -> [bs, h*w]
    memory = self.encoder(  # encode
        query=x,
        key=None,
        value=None,
        query_pos=pos_embed,
        query_key_padding_mask=mask)
    target = torch.zeros_like(query_embed)
    # out_dec: [num_layers, num_query, bs, dim]
    out_dec = self.decoder(  # decode
        query=target,
        key=memory,
        value=memory,
        key_pos=pos_embed,
        query_pos=query_embed,
        key_padding_mask=mask)
    out_dec = out_dec.transpose(1, 2)
    memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
    return out_dec, memory
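To make the shape bookkeeping concrete, here is a minimal standalone sketch of the reshaping that happens before the encoder/decoder calls (dummy tensors and a hypothetical 25x34 feature map, not mmdet code):

import torch

bs, c, h, w = 2, 256, 25, 34      # hypothetical feature-map size
num_query = 100

x = torch.randn(bs, c, h, w)                     # backbone/neck feature
mask = torch.zeros(bs, h, w, dtype=torch.bool)   # True = padded pixel
pos_embed = torch.randn(bs, c, h, w)             # positional encoding
query_embed = torch.randn(num_query, c)          # learned object queries

x = x.view(bs, c, -1).permute(2, 0, 1)                    # [h*w, bs, c]
pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1)    # [h*w, bs, c]
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)   # [num_query, bs, c]
mask = mask.view(bs, -1)                                  # [bs, h*w]

print(x.shape, query_embed.shape, mask.shape)
# torch.Size([850, 2, 256]) torch.Size([100, 2, 256]) torch.Size([2, 850])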
def forward(self,
            query,
            key=None,
            value=None,
            query_pos=None,
            key_pos=None,
            attn_masks=None,
            query_key_padding_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `TransformerDecoderLayer`.

    **kwargs contains some specific arguments of attentions.

    Args:
        query (Tensor): The input query with shape
            [num_queries, bs, embed_dims] if self.batch_first is False,
            else [bs, num_queries, embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs, embed_dims]
            if self.batch_first is False, else [bs, num_keys, embed_dims].
        value (Tensor): The value tensor with same shape as `key`.
        query_pos (Tensor): The positional encoding for `query`.
            Default: None.
        key_pos (Tensor): The positional encoding for `key`.
            Default: None.
        attn_masks (List[Tensor] | None): 2D Tensor used in calculation
            of corresponding attention. The length of it should equal to
            the number of `attention` in `operation_order`. Default: None.
        query_key_padding_mask (Tensor): ByteTensor for `query`, with
            shape [bs, num_queries]. Only used in `self_attn` layer.
            Defaults to None.
        key_padding_mask (Tensor): ByteTensor for `key`, with
            shape [bs, num_keys]. Default: None.

    Returns:
        Tensor: forwarded results with shape [num_queries, bs, embed_dims].
    """
    norm_index = 0
    attn_index = 0
    ffn_index = 0
    identity = query  # [100, bs, 256]
    if attn_masks is None:
        attn_masks = [None for _ in range(self.num_attn)]
    elif isinstance(attn_masks, torch.Tensor):
        attn_masks = [
            copy.deepcopy(attn_masks) for _ in range(self.num_attn)
        ]
        warnings.warn(f'Use same attn_mask in all attentions in '
                      f'{self.__class__.__name__} ')
    else:
        assert len(attn_masks) == self.num_attn, f'The length of ' \
            f'attn_masks {len(attn_masks)} must be equal ' \
            f'to the number of attention in ' \
            f'operation_order {self.num_attn}'

    for layer in self.operation_order:
        if layer == 'self_attn':  # used during encoding (and for query self-attention during decoding)
            temp_key = temp_value = query  # query attends to itself; [100, bs, 256] in the decoder
            query = self.attentions[attn_index](
                query,
                temp_key,
                temp_value,
                identity if self.pre_norm else None,
                query_pos=query_pos,
                key_pos=query_pos,
                attn_mask=attn_masks[attn_index],
                key_padding_mask=query_key_padding_mask,
                **kwargs)  # [100, bs, 256]
            attn_index += 1
            identity = query

        elif layer == 'norm':
            query = self.norms[norm_index](query)
            norm_index += 1

        elif layer == 'cross_attn':
            query = self.attentions[attn_index](
                query,
                key,    # encoder memory when decoding
                value,  # encoder memory when decoding
                identity if self.pre_norm else None,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_mask=attn_masks[attn_index],  # None
                key_padding_mask=key_padding_mask,  # note this mask
                **kwargs)
            attn_index += 1
            identity = query

        elif layer == 'ffn':
            query = self.ffns[ffn_index](
                query, identity if self.pre_norm else None)
            ffn_index += 1

    return query
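One thing worth spelling out: this single forward serves both the encoder and the decoder layers, and only operation_order decides which branches run. For a DETR-style setup the two orders look like this (a sketch of the usual configuration, treat the exact tuples as an assumption to check against your config):

# Encoder layer: self-attention over the h*w feature tokens
encoder_operation_order = ('self_attn', 'norm', 'ffn', 'norm')

# Decoder layer: self-attention over the object queries, then
# cross-attention from the queries to the encoder memory
decoder_operation_order = ('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')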
memory = self.encoder(  # encode
    query=x,
    key=None,
    value=None,
    query_pos=pos_embed,
    query_key_padding_mask=mask)
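For context, mask marks the padded pixels of the batched images (1/True = padding) and is downsampled to the feature-map resolution before it reaches the transformer. A rough, self-contained sketch of how such a mask can be produced, assuming the usual pad-and-interpolate preprocessing (sizes are made up):

import torch
import torch.nn.functional as F

# hypothetical padded batch: two images padded to 800x1216,
# the second image only occupies the top-left 600x900 region
bs, pad_h, pad_w = 2, 800, 1216
valid_hw = [(800, 1216), (600, 900)]

masks = torch.ones(bs, pad_h, pad_w)      # 1 = padding
for i, (vh, vw) in enumerate(valid_hw):
    masks[i, :vh, :vw] = 0                # 0 = real pixels

# downsample to the feature-map resolution used by the transformer
feat_h, feat_w = 25, 38
mask = F.interpolate(masks.unsqueeze(1), size=(feat_h, feat_w)) \
        .to(torch.bool).squeeze(1)
print(mask.shape)   # torch.Size([2, 25, 38]); flattened to [2, 950] in forward()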
def forward(self,
            query,
            key=None,
            value=None,
            identity=None,
            query_pos=None,
            key_pos=None,
            attn_mask=None,
            key_padding_mask=None,
            **kwargs):
    """Forward function for `MultiheadAttention`.

    **kwargs allow passing a more general data flow when combining
    with other operations in `transformerlayer`.

    Args:
        query (Tensor): The input query with shape [num_queries, bs,
            embed_dims] if self.batch_first is False, else
            [bs, num_queries, embed_dims].
        key (Tensor): The key tensor with shape [num_keys, bs, embed_dims]
            if self.batch_first is False, else [bs, num_keys, embed_dims].
            If None, the ``query`` will be used. Defaults to None.
        value (Tensor): The value tensor with same shape as `key`.
            Same in `nn.MultiheadAttention.forward`. Defaults to None.
            If None, the `key` will be used.
        identity (Tensor): This tensor, with the same shape as x,
            will be used for the identity link. If None, `x` will be
            used. Defaults to None.
        query_pos (Tensor): The positional encoding for query, with
            the same shape as `x`. If not None, it will be added to
            `x` before forward function. Defaults to None.
        key_pos (Tensor): The positional encoding for `key`, with the
            same shape as `key`. If not None, it will be added to `key`
            before forward function. If None, and `query_pos` has the
            same shape as `key`, then `query_pos` will be used for
            `key_pos`. Defaults to None.
        attn_mask (Tensor): ByteTensor mask with shape
            [num_queries, num_keys]. Same in
            `nn.MultiheadAttention.forward`. Defaults to None.
        key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
            Defaults to None.

    Returns:
        Tensor: forwarded results with shape [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries, embed_dims].
    """
    if key is None:
        key = query
    if value is None:
        value = key
    if identity is None:
        identity = query
    if key_pos is None:
        if query_pos is not None:
            # use query_pos if key_pos is not available
            if query_pos.shape == key.shape:
                key_pos = query_pos
            else:
                warnings.warn(f'position encoding of key is'
                              f'missing in {self.__class__.__name__}.')
    if query_pos is not None:
        query = query + query_pos  # add positional encoding
    if key_pos is not None:
        key = key + key_pos  # add positional encoding

    # Because the dataflow('key', 'query', 'value') of
    # ``torch.nn.MultiheadAttention`` is (num_query, batch,
    # embed_dims), We should adjust the shape of dataflow from
    # batch_first (batch, num_query, embed_dims) to num_query_first
    # (num_query ,batch, embed_dims), and recover ``attn_output``
    # from num_query_first to batch_first.
    if self.batch_first:
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

    out = self.attn(      # multi-head attention: 8 heads, 256/8 = 32 dims each, concatenated back afterwards
        query=query,      # standard transformer attention, no need to step inside
        key=key,          # it just computes the attention; keep an eye on the output shape
        value=value,
        attn_mask=attn_mask,
        key_padding_mask=key_padding_mask)[0]  # output is [h*w, bs, c] here (batch_first is False)

    if self.batch_first:
        out = out.transpose(0, 1)  # swap back to the batch-first layout

    return identity + self.dropout_layer(self.proj_drop(out))  # residual connection; dropout helps prevent overfitting
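The detail to notice above is that the positional encoding is added to query and key only, never to value, and this is repeated inside every layer. A minimal standalone sketch of the same pattern with a plain nn.MultiheadAttention (shapes mimic the encoder; the dropout and projection wrappers of the mmcv module are omitted):

import torch
import torch.nn as nn

embed_dims, num_heads, hw, bs = 256, 8, 850, 2
attn = nn.MultiheadAttention(embed_dims, num_heads)   # batch_first=False by default

x = torch.randn(hw, bs, embed_dims)                   # flattened feature map
pos = torch.randn(hw, bs, embed_dims)                 # positional encoding
key_padding_mask = torch.zeros(bs, hw, dtype=torch.bool)

q = k = x + pos        # position only enters query and key
v = x                  # value stays position-free
out, _ = attn(q, k, v, key_padding_mask=key_padding_mask)
out = x + out          # residual link (dropout omitted in this sketch)
print(out.shape)       # torch.Size([850, 2, 256])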
target = torch.zeros_like(query_embed)
# out_dec: [num_layers, num_query, bs, dim]
out_dec = self.decoder(  # decode
    query=target,  # [num_query, bs, dim]
    key=memory,  # [h*w, bs, c]
    value=memory,  # [h*w, bs, c]
    key_pos=pos_embed,  # [h*w, bs, c]
    query_pos=query_embed,  # [num_query, bs, dim]
    key_padding_mask=mask)  # [bs, h*w]
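Note that the decoder input target starts as all zeros; what distinguishes the 100 object queries is query_pos=query_embed, which in DETR comes from a learned embedding table. A hedged sketch of that setup (the nn.Embedding attribute name is illustrative; in mmdet it lives in the DETR head):

import torch
import torch.nn as nn

num_query, embed_dims, bs = 100, 256, 2
query_embedding = nn.Embedding(num_query, embed_dims)     # learned object queries

query_embed = query_embedding.weight                      # [num_query, embed_dims]
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)   # [num_query, bs, embed_dims]
target = torch.zeros_like(query_embed)                    # decoder starts from zeros

print(target.shape, query_embed.shape)   # torch.Size([100, 2, 256]) twice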
elif layer == 'cross_attn':
    query = self.attentions[attn_index](
        query,  # [num_query, bs, dim]
        key,  # [h*w, bs, c]
        value,  # [h*w, bs, c]
        identity if self.pre_norm else None,
        query_pos=query_pos,  # [num_query, bs, dim]
        key_pos=key_pos,  # [h*w, bs, c]
        attn_mask=attn_masks[attn_index],  # None
        key_padding_mask=key_padding_mask,  # [bs, h*w]
        **kwargs)
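In cross-attention the num_query object queries attend over all h*w memory positions, so the output keeps the query shape. A small standalone shape check with a plain nn.MultiheadAttention (sizes assumed):

import torch
import torch.nn as nn

embed_dims, num_heads, num_query, hw, bs = 256, 8, 100, 850, 2
cross_attn = nn.MultiheadAttention(embed_dims, num_heads)

query = torch.randn(num_query, bs, embed_dims)    # decoder queries (+ query_pos)
memory = torch.randn(hw, bs, embed_dims)          # encoder output (+ pos_embed on the key)
key_padding_mask = torch.zeros(bs, hw, dtype=torch.bool)

out, weights = cross_attn(query, memory, memory,
                          key_padding_mask=key_padding_mask)
print(out.shape, weights.shape)   # torch.Size([100, 2, 256]) torch.Size([2, 100, 850])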
This drops into a function whose principle is essentially the same as self-attention. If you want to step through it in detail, it is the multi_head_attention_forward function in open-mmlab2/lib/python3.7/site-packages/torch/nn/functional.py.
# This is just an excerpt of the part that implements the attention formula,
# to aid understanding; it is not the complete function, but the principle is as follows.
attn_output_weights = torch.bmm(q, k.transpose(1, 2))  # [8,100,32] @ [8,32,h*w] = [8,100,h*w]
assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]

if attn_mask is not None:
    if attn_mask.dtype == torch.bool:
        attn_output_weights.masked_fill_(attn_mask, float('-inf'))
    else:
        attn_output_weights += attn_mask

if key_padding_mask is not None:
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)  # [b,8,100,h*w]
    attn_output_weights = attn_output_weights.masked_fill(
        key_padding_mask.unsqueeze(1).unsqueeze(2),
        float('-inf'),
    )
    attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)

attn_output_weights = softmax(  # softmax over the key dimension
    attn_output_weights, dim=-1)
attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training)  # dropout on the attention weights

attn_output = torch.bmm(attn_output_weights, v)  # [8,100,h*w] @ [8,h*w,32] = [8,100,32]
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)  # merge heads back, e.g. [100, 1, 256]
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)  # final linear projection, y = xA^T + b

if need_weights:
    # average attention weights over heads
    attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
    return attn_output, attn_output_weights.sum(dim=1) / num_heads
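One thing the excerpt hides is the scaling: q has already been multiplied by 1/sqrt(head_dim) earlier in multi_head_attention_forward, which is why no division appears before the bmm. A self-contained sketch of the full formula softmax(QK^T / sqrt(d))V with a key padding mask (just the principle, not the torch implementation):

import math
import torch

def scaled_dot_product_attention(q, k, v, key_padding_mask=None):
    """q: [B, Lq, D], k/v: [B, Lk, D], key_padding_mask: [B, Lk] (True = masked)."""
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(q.size(-1))  # [B, Lq, Lk]
    if key_padding_mask is not None:
        scores = scores.masked_fill(key_padding_mask.unsqueeze(1), float('-inf'))
    weights = scores.softmax(dim=-1)
    return torch.bmm(weights, v), weights

q = torch.randn(2, 100, 32)
k = v = torch.randn(2, 850, 32)
mask = torch.zeros(2, 850, dtype=torch.bool)
out, w = scaled_dot_product_attention(q, k, v, mask)
print(out.shape, w.shape)   # torch.Size([2, 100, 32]) torch.Size([2, 100, 850])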