Line-by-line comments, line-by-line explanations. The code can be run directly.
Code from https://github.com/graykode/nlp-tutorial/tree/master/5-1.Transformer
import numpy as np
import torch
import torch.nn as nn
import math
import time
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
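
# Note (added; not part of the original listing): the globals referenced below
# (d_model, n_heads, n_layers, tgt_vocab_size, ...) are defined elsewhere in
# the full script. The shape comments such as [1, 5, 512], [1, 8, 5, 5] and
# [7, 512] imply values like the original repo's defaults, roughly:
#   d_model = 512        # embedding / model dimension
#   n_heads = 8          # attention heads per multi-head layer
#   tgt_vocab_size = 7   # target vocabulary size
#   n_layers = 6         # stacked decoder layers (repo default)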


# 13. MyDataset
class MyDataset(Dataset):
    # Store the data
    def __init__(self, enc_inputs, dec_inputs, target_batch):
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.target_batch = target_batch

    # Return the length of the dataset (how many rows of data there are)
    def __len__(self):
        return len(self.enc_inputs)
        # return self.enc_inputs.shape[0]

    # Return the sample at the given index; it has one dimension fewer than
    # the tensors returned by make_batch
    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.target_batch[idx]
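
# Usage sketch (added; not in the original comments): wrap the three tensors
# from make_batch in MyDataset and hand it to a DataLoader. With batch_size=1
# the default collate re-adds the batch dimension that __getitem__ drops:
#   enc_inputs, dec_inputs, target_batch = make_batch(sentences)
#   loader = DataLoader(MyDataset(enc_inputs, dec_inputs, target_batch), batch_size=1)
#   for enc_in, dec_in, tgt in loader:   # each is [1, 5]
#       ...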


# 12. make_batch
def make_batch(sentences):
    input_batch = [[src_vocab[n] for n in sentences[0].split()]]   # [[1, 2, 3, 4, 0]]
    output_batch = [[tgt_vocab[n] for n in sentences[1].split()]]  # [[5, 1, 2, 3, 4]]
    target_batch = [[tgt_vocab[n] for n in sentences[2].split()]]  # [[1, 2, 3, 4, 6]]
    return torch.LongTensor(input_batch), torch.LongTensor(output_batch), torch.LongTensor(target_batch)
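
# Note (added): sentences, src_vocab and tgt_vocab are globals defined earlier
# in the full script (not shown in this excerpt). Judging from the index
# comments above, the original repo's setup looks like:
#   sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
#   src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4}
#   tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'S': 5, 'E': 6}
# Each returned tensor therefore has shape [1, 5]: one sentence of 5 tokens.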


# 11. get_attn_subsequent_mask: upper-triangular mask that hides future
# positions from the decoder's self-attention
def get_attn_subsequent_mask(seq):
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]  # [1, 5, 5]
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # ndarray [1, 5, 5]
    # .byte() is equivalent to self.to(torch.uint8)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()  # [1, 5, 5]
    return subsequence_mask
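
# Quick check (added): for a toy length-3 sequence the mask is strictly
# upper-triangular, so position i may only attend to positions <= i.
#   >>> get_attn_subsequent_mask(torch.zeros(1, 3, dtype=torch.long))
#   tensor([[[0, 1, 1],
#            [0, 0, 1],
#            [0, 0, 0]]], dtype=torch.uint8)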


# 10. DecoderLayer: three sub-layers, namely masked multi-head self-attention,
# encoder-decoder (cross) attention, and a position-wise feed-forward network
class DecoderLayer(nn.Module):
    def __init__(self):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention()
        self.dec_enc_attn = MultiHeadAttention()
        self.pos_fnn = PoswiseFeedForwardNet()

    # dec_inputs: [1, 5, 512], enc_outputs: [1, 5, 512],
    # dec_self_attn_mask: [1, 5, 5], dec_enc_attn_mask: [1, 5, 5]
    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        # dec_outputs: [1, 5, 512], dec_self_attn: [1, 8, 5, 5]
        dec_outputs, dec_self_attn = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [1, 5, 512], dec_enc_attn: [1, 8, 5, 5]
        dec_outputs, dec_enc_attn = self.dec_enc_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        dec_outputs = self.pos_fnn(dec_outputs)  # [1, 5, 512]
        return dec_outputs, dec_self_attn, dec_enc_attn
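
# Note (added): in the self-attention call Q, K and V all come from the
# decoder inputs, while in the cross-attention call Q comes from the decoder
# and K/V come from the encoder outputs, which is how the decoder attends to
# the source sentence. In the original repo the residual connection and
# LayerNorm are applied inside MultiHeadAttention and PoswiseFeedForwardNet
# (defined elsewhere in the script), which is why they do not appear here.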


# 9. Decoder: three parts, namely the word embedding, the positional encoding,
# and the stack of decoder layers (masked self-attention, cross attention,
# feed-forward network)
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)  # weight: [7, 512]
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])  # see 10. DecoderLayer above

    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        dec_outputs = self.tgt_emb(dec_inputs)  # [1, 5, 512]
        # [1, 5, 512] ---> [5, 1, 512] ---> [1, 5, 512]
        dec_outputs = self.pos_emb(dec_outputs.transpose(0, 1)).transpose(0, 1)