当前位置:   article > 正文

transformer入门 注意力机制_transformer的注意力机制

transformer的注意力机制

目录

注意力原理:

注意力一句总结

简单模型和推理例子

qkv入门例子


这个讲的还可以:

https://zhuanlan.zhihu.com/p/420820453

注意力原理

Transformer 模型的注意力机制是其核心部分之一,它使得模型能够在处理序列数据时实现高效的关联性学习。它基于自注意力机制(Self-Attention),允许模型在不同位置的单词之间建立关联,从而更好地理解整个序列的语义。

注意力机制允许模型同时考虑输入序列中不同位置的信息,并计算出每个位置对于其他位置的重要性。它包含三个主要步骤:

  1. Query、Key、Value

    • 对于输入序列中的每个单词(或标记),通过三个线性变换(权重矩阵),分别得到 Query、Key 和 Value。
    • Query、Key 和 Value 是用来衡量每个单词在注意力机制中的重要性和影响力的向量表示。
  2. 计算注意力权重

    • 通过计算 Query 和所有 Key 之间的相似度(通常使用点积或其他方法),并应用 Softmax 函数获得每个单词对其他单词的注意力分数。
    • 这些分数决定了每个单词应该关注哪些位置的信息。
  3. 加权求和

    • 将每个 Value 根据对应的注意力权重进行加权求和,得到该位置的最终表示。
    • 这个求和过程能够使得模型更加关注与当前位置相关的信息。

这种注意力机制的优点在于,它允许模型对输入序列中的不同部分进行灵活的关注,而不是依赖于固定的滑动窗口或其他类似的固定模式。Transformer 中多头注意力(Multi-Head Attention)机制则是对这种注意力机制的扩展,允许模型同时关注来自不同表示空间的信息。

总的来说,Transformer 模型通过自注意力机制实现了对序列数据的建模,使得模型能够更好地理解序列中不同部分之间的关系,并在各种自然语言处理任务中取得了很好的效果。

注意力一句总结

它基于自注意力机制(Self-Attention),允许模型在不同位置的单词之间建立关联,从而更好地理解整个序列的语义。

自注意力机制允许模型同时考虑输入序列中不同位置的信息,并计算出每个位置对于其他位置的重要性。

简单模型和推理例子

  1. import time
  2. import numpy as np
  3. import torch
  4. from torch import nn
  5. # 定义Transformer模型
  6. class TimeSeriesTransformer(nn.Module):
  7. def __init__(self, input_size, num_layers, num_heads, dropout=0.1):
  8. super(TimeSeriesTransformer, self).__init__()
  9. self.model_type = 'Transformer'
  10. self.src_mask = None
  11. self.pos_encoder = PositionalEncoding(input_size, dropout)
  12. self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_size, nhead=num_heads, dropout=dropout)
  13. self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
  14. self.decoder = nn.Linear(288, 32)
  15. def forward(self, src):
  16. src = self.pos_encoder(src)
  17. output = self.transformer_encoder(src, self.src_mask)
  18. output=output.reshape(output.size(0), -1)
  19. output = self.decoder(output)
  20. return output.reshape(-1, 4, 8)
  21. # 位置编码
  22. class PositionalEncoding(nn.Module):
  23. def __init__(self, input_size, dropout=0.1, max_len=5000):
  24. super(PositionalEncoding, self).__init__()
  25. self.dropout = nn.Dropout(p=dropout)
  26. pe = torch.zeros(max_len, input_size)
  27. position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
  28. # 生成正弦用的div_term,对于维度为9的输入,我们需要5个正弦值
  29. div_term_even = torch.exp(torch.arange(0, input_size, 2).float() * (-np.log(10000.0) / input_size))
  30. # 生成余弦用的div_term,对于维度为9的输入,我们需要4个余弦值
  31. div_term_odd = torch.exp(torch.arange(1, input_size, 2).float() * (-np.log(10000.0) / input_size))
  32. # 正弦赋值,对于维度为9的输入,我们应该赋值给索引0, 2, 4, 6, 8
  33. pe[:, 0::2] = torch.sin(position * div_term_even.unsqueeze(0))
  34. # 余弦赋值,对于维度为9的输入,我们应该赋值给索引1, 3, 5, 7
  35. pe[:, 1::2] = torch.cos(position * div_term_odd.unsqueeze(0))
  36. pe = pe.unsqueeze(0).transpose(0, 1)
  37. self.register_buffer('pe', pe)
  38. def forward(self, x):
  39. x = x + self.pe[:x.size(0), :]
  40. return self.dropout(x)
  41. if __name__ == '__main__':
  42. net = TimeSeriesTransformer(input_size=9, num_layers=10, num_heads=3)
  43. for i in range(10):
  44. data = torch.rand(16, 32,9)
  45. start = time.time()
  46. out = net(data)
  47. print('time', time.time() - start, out.size())

qkv入门例子

  1. import numpy as np
  2. import torch
  3. import torch.nn.functional as F
  4. # 假设我们有两个向量,每个向量的维度为 5
  5. x = np.array([[1, 0, 1, 0, 1], [0, 1, 0, 1, 0]], dtype=np.float32)
  6. # 将 numpy 数组转换为 PyTorch 张量
  7. x = torch.tensor(x)
  8. # 初始化权重矩阵 Wq, Wk, Wv
  9. d_model = 5 # 向量维度
  10. Wq = torch.randn(d_model, d_model)
  11. Wk = torch.randn(d_model, d_model)
  12. Wv = torch.randn(d_model, d_model)
  13. # 计算 Q, K, V
  14. Q = torch.matmul(x, Wq)
  15. K = torch.matmul(x, Wk)
  16. V = torch.matmul(x, Wv)
  17. # 计算注意力得分
  18. scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_model)
  19. # 应用 softmax 函数来获取注意力权重
  20. attention_weights = F.softmax(scores, dim=-1)
  21. # 计算注意力结果
  22. attention_output = torch.matmul(attention_weights, V)
  23. print("Attention Weights:")
  24. print(attention_weights)
  25. print("Attention Output:")
  26. print(attention_output)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/906620
推荐阅读
相关标签
  

闽ICP备14008679号