当前位置:   article > 正文

pytorch 中 多头注意力机制 MultiHeadAttention的代码实现及应用_多头注意力机制代码

多头注意力机制代码

本文将对 Scaled Dot-Product Attention,Multi-head attentionSelf-attentionTransformer等概念做一个简要介绍和区分。最后对通用的 Multi-head attention 进行代码实现和应用

一、概念:

1. Scaled Dot-Product Attention

在实际应用中,经常会用到 Attention 机制,其中最常用的是Scaled Dot-Product Attention,它是通过计算query和key之间的点积 来作为 之间的相似度。

  • Scaled 指的是 Q和K计算得到的相似度 再经过了一定的量化,具体就是 除以 根号下K_dim;
  • Dot-Product 指的是 Q和K之间 通过计算点积作为相似度;
  • Mask 可选择性 目的是将 padding的部分 填充负无穷,这样算softmax的时候这里就attention为0,从而避免padding带来的影响.

2. Multi-head attention

是在 Scaled Dot-Product Attention 的基础上,分成多个头,也就是有多个Q、K、V并行进行计算attention,可能侧重与不同的方面的相似度和权重。

3. Self-attention

自注意力机制 是在Scaled Dot-Product Attention 以及Multi-head attention的基础上的一种应用场景,就是指 QKV的来源是相同的自己和自己计算attention,类似于经过一个线性层等,输入输出等长。

如果QKV的来源是不同的,不能叫做 self-attention,只能是attention。比如GST中的KV是随机初始化的多个token,而Q是reference encoder得到的梅尔谱的一帧。同理,Q也可以是随机初始化的一个,而KV是来自于输入,这样就可以将某一变长长度为N的输入计算attention得到一个长度为1的向量。

4. Transformer

Transformer 是指 在Scaled Dot-Product Attention 以及Multi-head attention以及Self-attention的基础上的一种通用的模型框架,它包括Positional Encoding,Encoder,Decoder等等。Transformer不等于Self-attention。

二、代码实现

 平时经常会用到Attention操作,接下来对Multi-head Attention 进行代码整理和实现,方便以后可以直接调用接口,其中单头注意力机制作为其中的一种特殊情况。

  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import torch.nn.functional as F
  5. class MultiHeadAttention(nn.Module):
  6. '''
  7. input:
  8. query --- [N, T_q, query_dim]
  9. key --- [N, T_k, key_dim]
  10. mask --- [N, T_k]
  11. output:
  12. out --- [N, T_q, num_units]
  13. scores -- [h, N, T_q, T_k]
  14. '''
  15. def __init__(self, query_dim, key_dim, num_units, num_heads):
  16. super().__init__()
  17. self.num_units = num_units
  18. self.num_heads = num_heads
  19. self.key_dim = key_dim
  20. self.W_query = nn.Linear(in_features=query_dim, out_features=num_units, bias=False)
  21. self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
  22. self.W_value = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
  23. def forward(self, query, key, mask=None):
  24. querys = self.W_query(query) # [N, T_q, num_units]
  25. keys = self.W_key(key) # [N, T_k, num_units]
  26. values = self.W_value(key)
  27. split_size = self.num_units // self.num_heads
  28. querys = torch.stack(torch.split(querys, split_size, dim=2), dim=0) # [h, N, T_q, num_units/h]
  29. keys = torch.stack(torch.split(keys, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
  30. values = torch.stack(torch.split(values, split_size, dim=2), dim=0) # [h, N, T_k, num_units/h]
  31. ## score = softmax(QK^T / (d_k ** 0.5))
  32. scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
  33. scores = scores / (self.key_dim ** 0.5)
  34. ## mask
  35. if mask is not None:
  36. ## mask: [N, T_k] --> [h, N, T_q, T_k]
  37. mask = mask.unsqueeze(1).unsqueeze(0).repeat(self.num_heads,1,querys.shape[2],1)
  38. scores = scores.masked_fill(mask, -np.inf)
  39. scores = F.softmax(scores, dim=3)
  40. ## out = score * V
  41. out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
  42. out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0) # [N, T_q, num_units]
  43. return out,scores

 三、实际应用:

1. 接口调用:

  1. ## 类实例化
  2. attention = MultiHeadAttention(3,4,5,1)
  3. ## 输入
  4. qurry = torch.randn(8, 2, 3)
  5. key = torch.randn(8, 6 ,4)
  6. mask = torch.tensor([[False, False, False, False, True, True],
  7. [False, False, False, True, True, True],
  8. [False, False, False, False, True, True],
  9. [False, False, False, True, True, True],
  10. [False, False, False, False, True, True],
  11. [False, False, False, True, True, True],
  12. [False, False, False, False, True, True],
  13. [False, False, False, True, True, True],])
  14. ## 输出
  15. out, scores = attention(qurry, key, mask)
  16. print('out:', out.shape) ## torch.Size([8, 2, 5])
  17. print('scores:', scores.shape) ## torch.Size([1, 8, 2, 6])

2. mask的作用:

mask之前的 scores:

mask之后的 scores:

softmax之后的scores:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/347908
推荐阅读
相关标签
  

闽ICP备14008679号