
Transformer Encoder-Decoder Architecture: Introduction + Code Implementation

1. Introduction to the Transformer Architecture:

The Transformer encoder-decoder is built by stacking modules based on self-attention. The embedding representations of the source sequence (Input) and the target sequence (Target) are added to positional encodings and then fed into the encoder and the decoder, respectively.

Viewed at a high level, the Transformer encoder is a stack of multiple identical layers, each with two sublayers. The first sublayer is multi-head self-attention pooling; the second is a position-wise feed-forward network. When computing the encoder's self-attention, the queries, keys, and values all come from the output of the previous encoder layer, and every sublayer uses a residual connection. In the Transformer, for any input x ∈ ℝ^d at any position of the sequence, we require sublayer(x) ∈ ℝ^d, so that the residual connection x + sublayer(x) ∈ ℝ^d is well defined. The addition of the residual connection is immediately followed by layer normalization. As a result, the Transformer encoder outputs a d-dimensional representation vector for every position of the input sequence.

The Transformer decoder is likewise a stack of multiple identical layers that use residual connections and layer normalization. Besides the two sublayers described for the encoder, the decoder inserts a third sublayer between them, called encoder-decoder attention. In encoder-decoder attention, the queries come from the output of the previous decoder layer, while the keys and values come from the output of the entire encoder. In the decoder's self-attention, the queries, keys, and values all come from the output of the previous decoder layer, but each position in the decoder may only attend to positions up to and including itself. This masked attention preserves the auto-regressive property, ensuring that predictions depend only on the output tokens that have already been generated.
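Putting the residual connection and layer normalization together, each of these sublayers computes

$$\mathrm{output} = \mathrm{LayerNorm}\big(x + \mathrm{Sublayer}(x)\big), \qquad x,\ \mathrm{Sublayer}(x) \in \mathbb{R}^d,$$

which is what the AddNorm class in the code section below implements (with dropout applied to Sublayer(x) before the addition).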

2. The Role of the Embedding Layer and Positional Encoding:

For example, suppose we are working on an English-to-Spanish translation task, where one sample's source sequence is "The ball is blue" and its target sequence is "La bola es azul".

The source sequence first passes through the Embedding and Position Encoding layers, which generate an embedding for each word in the sequence. The embeddings are then passed into the encoder and arrive at the Attention module.

In the Attention module, the embedded sequence is passed through three linear layers (Linear layers), producing three separate matrices: Query, Key, and Value. These three matrices are used to compute the attention scores. Each "row" of these matrices corresponds to one word in the source sequence.

A simple illustration of positional encoding (the original figure is not reproduced here).
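For reference, the fixed sinusoidal encoding implemented by the PositionalEncoding class in the code section below is

$$PE_{(pos,\,2i)} = \sin\!\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d}}\right),$$

where $pos$ is the position in the sequence, $i$ indexes the embedding dimensions, and $d$ is the embedding size (num_hiddens). The resulting matrix is simply added to the word embeddings.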

3. Multi-Head Self-Attention:

Multi-head dot-product attention:

The Query matrix is multiplied with the transpose of the Key matrix, producing an intermediate matrix, the so-called "factor matrix". Each cell of the factor matrix is the dot product between two word vectors.

For example, each column of row 4 of the factor matrix corresponds to the dot product between the Q4 vector and one of the K vectors, and column 2 of the factor matrix corresponds to the dot products between each Q vector and the K2 vector. This "factor matrix" holds the attention scores (Attention Score).

The factor matrix is then passed through a Softmax function to produce a set of probability distributions, the attention weights (Attention Weights), which are then matrix-multiplied with the V matrix to produce the attention pooling (Attention Pooling) output. Row 4 of the output matrix is therefore the weighted sum of all Value vectors, with weights derived from Q4's dot products with every Key vector.

The attention output can be understood as a word's "encoded value". This encoded value is a weighted combination of the words in the Value matrix, where the weights come from the "factor matrix", and each weight in the factor matrix is the dot product of that particular word's Query vector with a Key vector. To restate the steps (a small numerical sketch follows this list):

1 - A word's attention output can be understood as that word's "encoded value": the representation vector the attention mechanism ultimately assigns to each word.

2 - This "encoded value" is obtained as a weighted sum of the value vectors of every word in the Value matrix.

3 - The weights used in that sum are the corresponding attention weights from the "factor matrix".

4 - The attention weights in the "factor matrix" are computed as the dot products between that word's Query vector and the Key vectors of all words.

4. Passing Keys and Values Between the Transformer Encoder and Decoder:

在 "Encoder-Decoder Attention "中,Query 来自Traget序列,而Key-Value来自Input序列。

5. Hands-On Transformer Code:

import math
import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l
import matplotlib.pyplot as plt


# 1.1, Build the multi-head attention components
def masked_softmax(X, valid_lens):
    """Perform a softmax operation by masking elements on the last axis"""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # Masked elements on the last axis are replaced with a very large negative
        # value, so that their softmax outputs become 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)


class DotProductAttention(nn.Module):
    """Scaled dot-product attention"""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[2]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)


def transpose_qkv(tensor, num_heads):
    """Reshape for the parallel computation of multiple attention heads"""
    # Shape of the input tensor:
    # (batch_size, no. of queries or key-value pairs, num_hiddens)
    # Shape after the first reshape:
    # (batch_size, no. of queries or key-value pairs, num_heads, num_hiddens/num_heads)
    tensor = tensor.reshape(tensor.shape[0], tensor.shape[1], num_heads, -1)
    # Shape after the permute:
    # (batch_size, num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    tensor = tensor.permute(0, 2, 1, 3)
    # Final output shape:
    # (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
    return tensor.reshape(-1, tensor.shape[2], tensor.shape[3])


def transpose_output(tensor, num_heads):
    """Reverse the operation of transpose_qkv"""
    tensor = tensor.reshape(-1, num_heads, tensor.shape[1], tensor.shape[2])
    tensor = tensor.permute(0, 2, 1, 3)
    return tensor.reshape(tensor.shape[0], tensor.shape[1], -1)
class MultiHeadAttention(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        print('MultiHeadAttention __init__:', file=log)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.weight_query = nn.Linear(query_size, num_hiddens, bias=bias)
        self.weight_key = nn.Linear(key_size, num_hiddens, bias=bias)
        self.weight_value = nn.Linear(value_size, num_hiddens, bias=bias)
        self.weight_output = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # Shape of queries, keys, values:
        # (batch_size, no. of queries or key-value pairs, num_hiddens)
        # Shape of valid_lens:
        # (batch_size,) or (batch_size, no. of queries)
        # After the transformation, the shape of the output queries, keys, values is
        # (batch_size*num_heads, no. of queries or key-value pairs, num_hiddens/num_heads)
        queries = transpose_qkv(self.weight_query(queries), self.num_heads)
        # print('queries:', file=log)
        # print(queries.shape, file=log)
        # print(queries, file=log)
        keys = transpose_qkv(self.weight_key(keys), self.num_heads)
        # print('keys:', file=log)
        # print(keys.shape, file=log)
        # print(keys, file=log)
        values = transpose_qkv(self.weight_value(values), self.num_heads)
        # print('values:', file=log)
        # print(values.shape, file=log)
        # print(values, file=log)
        if valid_lens is not None:
            # On axis 0, copy the first item (scalar or vector) num_heads times,
            # then copy the second item in the same way, and so on.
            valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)
            # print('valid_lens:', file=log)
            # print(valid_lens.shape, file=log)
            # print(valid_lens, file=log)
        # Shape of output: (batch_size*num_heads, no. of queries, num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)
        # print('output:', file=log)
        # print(output.shape, file=log)
        # print(output, file=log)
        # Shape of output_concat: (batch_size, no. of queries, num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        # print('output_concat:', file=log)
        # print(output_concat.shape, file=log)
        # print(output_concat, file=log)
        return self.weight_output(output_concat)
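# Optional shape check for MultiHeadAttention, in the same style as the other test blocks
# in this listing (a sketch with illustrative sizes; it opens the log file first because
# MultiHeadAttention.__init__ writes to it):
"""
log = open('transformer1.txt', mode='a', encoding='utf-8')
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens, num_hiddens, num_heads, 0.5)
attention.eval()
batch_size, num_queries, num_kvpairs = 2, 4, 6
valid_lens = torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
print(attention(X, Y, Y, valid_lens).shape)  # expected: torch.Size([2, 4, 100])
"""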
# 2.1, Position-wise feed-forward network
class PositionWiseFFN(nn.Module):
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        print('PositionWiseFFN __init__:', file=log)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))


"""
FFN = PositionWiseFFN(4, 4, 8)
FFN.eval()
print(FFN(torch.ones((2, 3, 4))).shape)
print(FFN(torch.ones((2, 3, 4))))
"""

# 3.1, Residual connection and layer normalization
ln = torch.nn.LayerNorm(2)
bn = torch.nn.BatchNorm1d(2)
X = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
# Compute the mean and variance of X in training mode
print('layer norm:', ln(X), '\nbatch norm:', bn(X))


# 3.2, Implement the AddNorm class using a residual connection followed by layer
# normalization; dropout is also applied as a regularization method.
class AddNorm(nn.Module):
    def __init__(self, normalized_shape, dropout, **kwargs):
        super(AddNorm, self).__init__(**kwargs)
        print('AddNorm __init__:', file=log)
        self.dropout = torch.nn.Dropout(dropout)
        self.ln = torch.nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        return self.ln(self.dropout(Y) + X)


# 3.3, The residual connection requires the two inputs to have the same shape,
# so that the output tensor after the addition keeps that shape as well.
"""
add_norm = AddNorm([3, 4], 0.5)
add_norm.eval()
tensor1 = add_norm(torch.ones((2, 3, 4)), torch.ones((2, 3, 4)))
print(tensor1.shape)
print(tensor1)
"""
# 4.1, Positional encoding
class PositionalEncoding(nn.Module):
    """Positional encoding"""
    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Create a long enough P
        self.P = torch.zeros((1, max_len, num_hiddens))
        print('PositionalEncoding:', file=log)
        print(self.P.shape, file=log)
        X = torch.arange(0, max_len, step=1, dtype=torch.float32).reshape(-1, 1)
        print('X1:', file=log)
        print(X.shape, file=log)
        Y = torch.pow(10000, torch.arange(0, num_hiddens, step=2, dtype=torch.float32) / num_hiddens)
        print('Y:', file=log)
        print(Y.shape, file=log)
        print(Y, file=log)
        X = X / Y
        print('X3:', file=log)
        print(X.shape, file=log)
        print(X, file=log)
        self.P[:, :, 0:num_hiddens:2] = torch.sin(X)
        self.P[:, :, 1:num_hiddens:2] = torch.cos(X)
        print(self.P.shape, file=log)
        print(self.P, file=log)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        print('self.P:', file=log)
        print(self.P[:, :X.shape[1], :].to(X.device), file=log)
        print('X4:', file=log)
        print(X.shape, file=log)
        print(X, file=log)
        return self.dropout(X)


# 4.2, Positional encoding test
"""
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
tensor_pos = torch.zeros((1, num_steps, encoding_dim))
print('tensor_pos:')
print(tensor_pos.shape)
print(tensor_pos)
X = pos_encoding(tensor_pos)
print('X4:')
print(X.shape)
print(X)
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row(position)', figsize=(6, 2.5), legend=["Col %d" % d for d in torch.arange(6, 10)])
plt.show()
"""
# 5.1, Implement a single encoder layer. The EncoderBlock class contains two sublayers,
# multi-head self-attention and a position-wise feed-forward network,
# each followed by a residual connection and layer normalization.
class EncoderBlock(nn.Module):
    def __init__(self, key_size, query_size, value_size, num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens, num_heads, dropout, use_bias=False, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        print('EncoderBlock __init__:', file=log)
        self.attention = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        print('EncoderBlock forward:', file=log)
        atten = self.attention(X, X, X, valid_lens)
        print('atten:', file=log)
        print(atten.shape, file=log)
        print(atten, file=log)
        Y = self.addnorm1(X, atten)
        print('Y:', file=log)
        print(Y.shape, file=log)
        print(Y, file=log)
        ff = self.ffn(Y)
        print('ff:', file=log)
        print(ff.shape, file=log)
        print(ff, file=log)
        return self.addnorm2(Y, ff)


"""
log = open('transformer1.txt', mode='a', encoding='utf-8')
# X = torch.ones((2, 100, 24))
valid_lens = torch.tensor([3, 2])
encoder_blk = EncoderBlock(24, 24, 24, 24, [100, 24], 24, 48, 8, 0.5)
encoder_blk.eval()
# encoder_blk(X, valid_lens).shape
"""


# 5.2, Implement the Transformer encoder by stacking num_layers instances of EncoderBlock.
class TransformerEncoder(d2l.Encoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            print('EncoderBlock instance index:', file=log)
            print(i, file=log)
            self.blks.add_module("block" + str(i), EncoderBlock(key_size, query_size, value_size,
                                                                num_hiddens, norm_shape, ffn_num_input,
                                                                ffn_num_hiddens, num_heads, dropout, use_bias))

    def forward(self, X, valid_lens, *args):
        # Since the positional encoding values are between -1 and 1,
        # the embedding values are multiplied by the square root of the embedding
        # dimension to rescale them before they are added to the positional encoding.
        X = self.embedding(X) * math.sqrt(self.num_hiddens)
        print('TransformerEncoder:', file=log)
        print(X.shape, file=log)
        print(X, file=log)
        X = self.pos_encoding(X)
        print('TransformerEncoder:', file=log)
        print(X.shape, file=log)
        print(X, file=log)
        self.attention_weights = [None] * len(self.blks)
        for i, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            self.attention_weights[i] = blk.attention.attention.attention_weights
        return X


# 6.1, Create a two-layer Transformer encoder with the specified hyperparameters.
# The shape of the Transformer encoder output is (batch size, no. of time steps, num_hiddens).
"""
encoder = TransformerEncoder(200, 24, 24, 24, 24, [100, 24], 24, 48, 8, 2, 0.5)
encoder.eval()
valid_lens = torch.tensor([3, 2])
encoder(torch.ones((2, 100), dtype=torch.long), valid_lens).shape
"""
# 7.1, The decoder contains three sublayers: decoder self-attention,
# encoder-decoder attention, and a position-wise feed-forward network.
class DecoderBlock(nn.Module):
    """The i-th block in the decoder"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        print('DecoderBlock __init__:', file=log)
        self.i = i
        self.attention1 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.attention2 = MultiHeadAttention(key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(ffn_num_input, ffn_num_hiddens, num_hiddens)
        self.addnorm3 = AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        enc_outputs, enc_valid_lens = state[0], state[1]
        # During training, all tokens of the output sequence are processed at the same
        # time, so state[2][self.i] is initialized to None.
        # During prediction, the output sequence is decoded token by token, so
        # state[2][self.i] contains the representations decoded by the i-th block
        # up to the current time step.
        print('DecoderBlock forward:', file=log)
        if state[2][self.i] is None:
            key_values = X
        else:
            key_values = torch.cat((state[2][self.i], X), axis=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape of dec_valid_lens: (batch_size, num_steps),
            # where every row is [1, 2, ..., num_steps]
            dec_valid_lens = torch.arange(1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None
        # Self-attention
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Encoder-decoder attention.
        # Shape of enc_outputs: (batch_size, num_steps, num_hiddens)
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        return self.addnorm3(Z, self.ffn(Z)), state


class TransformerDecoder(d2l.AttentionDecoder):
    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                 num_heads, num_layers, dropout, **kwargs):
        super(TransformerDecoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.num_layers = num_layers
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module("block" + str(i),
                                 DecoderBlock(key_size, query_size, value_size, num_hiddens,
                                              norm_shape, ffn_num_input, ffn_num_hiddens,
                                              num_heads, dropout, i))
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, enc_valid_lens, *args):
        return [enc_outputs, enc_valid_lens, [None] * self.num_layers]

    def forward(self, X, state):
        X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
        self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
        for i, blk in enumerate(self.blks):
            X, state = blk(X, state)
            # Decoder self-attention weights
            self._attention_weights[0][i] = blk.attention1.attention.attention_weights
            # Encoder-decoder attention weights
            self._attention_weights[1][i] = blk.attention2.attention.attention_weights
        return self.dense(X), state

    @property
    def attention_weights(self):
        return self._attention_weights


# 8.1, Training
num_hiddens, num_layers, dropout, batch_size, num_steps = 32, 2, 0.1, 64, 10
lr, num_epochs, device = 0.005, 200, d2l.try_gpu()
ffn_num_input, ffn_num_hiddens, num_heads = 32, 64, 4
key_size, query_size, value_size = 32, 32, 32
norm_shape = [32]
log = open('transformer1.txt', mode='a', encoding='utf-8')
# 8.2, Load the dataset
train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = TransformerEncoder(len(src_vocab), key_size, query_size, value_size,
                             num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, num_layers, dropout)
decoder = TransformerDecoder(len(tgt_vocab), key_size, query_size, value_size,
                             num_hiddens, norm_shape, ffn_num_input, ffn_num_hiddens,
                             num_heads, num_layers, dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
plt.show()
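After training, the model can translate a few sentences and report BLEU. The following is a minimal sketch that assumes the d2l helpers d2l.predict_seq2seq and d2l.bleu are available in the installed d2l version (they are part of the book's saved functions):

# 9.1, Prediction (sketch; assumes d2l.predict_seq2seq and d2l.bleu exist in this d2l version)
engs = ['go .', 'i lost .', "he's calm .", "i'm home ."]
fras = ['va !', "j'ai perdu .", 'il est calme .', 'je suis chez moi .']
for eng, fra in zip(engs, fras):
    translation, dec_attention_weight_seq = d2l.predict_seq2seq(
        net, eng, src_vocab, tgt_vocab, num_steps, device, True)
    print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')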
