The source code of the MultiheadAttention module that ships with PyTorch 1.10 is shown below. The basic idea is the same as the multi-head attention in the Harvard annotated Transformer from the previous post: split hidden_size across head_num heads, run self-attention on each head separately, and finally concatenate the results (a rough sketch of this idea follows).
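To make the split-and-concat idea concrete, here is a minimal sketch. It is not the actual PyTorch implementation: the learned input/output projections are omitted, and the batch-first shapes and sizes are made-up values.

import torch

# Assumed batch-first layout: x is (batch_size, seq_len, embed_dim).
batch_size, seq_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads

x = torch.rand(batch_size, seq_len, embed_dim)
# (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, head_dim)
q = k = v = x.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5   # (batch, heads, seq, seq)
attn = torch.softmax(scores, dim=-1)
out = attn @ v                                        # (batch, heads, seq, head_dim)
# Concatenate the heads back into embed_dim.
out = out.transpose(1, 2).reshape(batch_size, seq_len, embed_dim)
print(out.shape)  # torch.Size([2, 5, 16])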
What really needs to be understood in this code are key_padding_mask and attn_mask.
class MultiheadAttention(nn.Module):
    __constants__ = ['batch_first']
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(MultiheadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def __setstate__(self, state):
        # Support loading old MultiheadAttention checkpoints generated by v1.1.0
        if '_qkv_same_embed_dim' not in state:
            state['_qkv_same_embed_dim'] = True

        super(MultiheadAttention, self).__setstate__(state)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight)
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, self.embed_dim, self.num_heads,
                self.in_proj_weight, self.in_proj_bias,
                self.bias_k, self.bias_v, self.add_zero_attn,
                self.dropout, self.out_proj.weight, self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)
        if self.batch_first:
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
Based on the earlier understanding, attn_mask ought to have shape (batch_size, seq_len, seq_len), but when actually running the code this input turns out to be wrong: it always throws the same error complaining about the shape of the 3D attn_mask.
This error shows that in newer PyTorch versions a 3D attn_mask is no longer (batch_size, seq_len, seq_len) but must be (batch_size*num_heads, seq_len, seq_len). In other words, when constructing attn_mask you have to provide a mask for every head: the implementation assumes by default that each head may use a different attn_mask.
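In the common case where every sample and every head should share the same (seq_len, seq_len) mask (for example a causal mask), one way to build a mask of the required shape is sketched below; the sizes are made-up, and repeat_interleave is used so that the heads belonging to one sample stay adjacent in the first dimension.

import torch

batch_size, num_heads, seq_len = 2, 4, 5

# A single (seq_len, seq_len) mask; True means "do not attend" in PyTorch's convention.
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)

# Expand to (batch_size, seq_len, seq_len), then repeat each batch element num_heads
# times to get (batch_size * num_heads, seq_len, seq_len).
attn_mask = causal.unsqueeze(0).expand(batch_size, -1, -1)
attn_mask = attn_mask.repeat_interleave(num_heads, dim=0)
print(attn_mask.shape)  # torch.Size([8, 5, 5])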
Unlike attn_mask, the purpose of key_padding_mask is to shield attention from padded positions. If the padded input_id has shape (2, 5), then key_padding_mask also has shape (2, 5).
An example follows. If utterance utter2 contains 3 tokens and is padded to length 5, its key_padding_mask is [0, 0, 0, 1, 1].
Note in particular: in the paper (and in the Harvard annotated Transformer mentioned before), a mask value of True usually means the position does participate in attention. The PyTorch implementation is the opposite. According to the official documentation, for a bool mask, True positions are excluded from attention and False positions participate; for a byte mask, non-zero positions are excluded and zero positions participate.
import torch
from torch.nn.utils.rnn import pad_sequence

# Two "utterances" with 5 and 3 tokens respectively, each token a 10-dim feature.
utter1 = torch.rand(5, 10)
utter2 = torch.rand(3, 10)
# Pad to a common length: input has shape (batch_size=2, seq_len=5, feature=10).
input = pad_sequence([utter1, utter2], padding_value=0, batch_first=True)
# True (1) marks a padded position that should be ignored, per PyTorch's convention.
key_mask1 = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
key_mask2 = torch.tensor([0, 0, 0, 1, 1], dtype=torch.bool)
key_mask = torch.stack([key_mask1, key_mask2])  # (batch_size=2, seq_len=5)
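Continuing from the snippet above (input and key_mask are reused; embed_dim=10 matches the feature size, and num_heads=2 is an arbitrary choice), a minimal usage sketch for self-attention over the padded batch:

import torch.nn as nn

# batch_first=True so the (batch, seq, feature) layout of `input` can be used directly.
mha = nn.MultiheadAttention(embed_dim=10, num_heads=2, batch_first=True)

# Self-attention: query = key = value = input.
# True positions in key_padding_mask are ignored when attending.
attn_output, attn_weights = mha(input, input, input, key_padding_mask=key_mask)
print(attn_output.shape)   # torch.Size([2, 5, 10])
print(attn_weights.shape)  # torch.Size([2, 5, 5]), averaged over the two heads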