当前位置:   article > 正文

人工智能(Pytorch)搭建模型5-注意力机制模型的构建与GRU模型融合应用_gru-selfattention pytorch代码

gru-selfattention pytorch代码

大家好,我是微学AI,今天给大家介绍一下人工智能(Pytorch)搭建模型5-注意力机制模型的构建与GRU模型融合应用。注意力机制是一种神经网络模型,在序列到序列的任务中,可以帮助解决输入序列较长时难以获取全局信息的问题。该模型通过对输入序列不同部分赋予不同的权重,以便在每个时间步骤上更好地关注需要处理的信息。在编码器-解码器(Encoder-Decoder)框架中,编码器将输入序列映射为一系列向量,而解码器则在每个时间步骤上生成输出序列。在此过程中,解码器需要对编码器的所有时刻进行“注意”,以了解哪些输入对当前时间步骤最重要。

在注意力机制中,解码器会计算每个编码器输出与当前解码器隐藏状态之间的相关度,并将其转化为注意力权重,以确定每个编码器输出对当前时刻解码器状态的贡献。这些权重被用于加权求和编码器输出,从而得到一个上下文向量,该向量包含有关输入序列的重要信息,有助于提高模型的性能和泛化能力。

b5c612de9ab741fa85176cf6af8c13a2.png

一、注意力机制模型构建

  1. # 1. 导入所需库
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. import numpy as np
  6. from torch.utils.data import Dataset, DataLoader
  7. class Attention(nn.Module):
  8. def __init__(self, hidden_size):
  9. super(Attention, self).__init__()
  10. self.hidden_size = hidden_size
  11. self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
  12. self.v = nn.Linear(hidden_size, 1, bias=False)
  13. def forward(self, hidden, encoder_outputs):
  14. max_len = encoder_outputs.size(1)
  15. repeated_hidden = hidden.unsqueeze(1).repeat(1, max_len, 1)
  16. energy = torch.tanh(self.attn(torch.cat((repeated_hidden, encoder_outputs), dim=2)))
  17. attention_scores = self.v(energy).squeeze(2)
  18. attention_weights = nn.functional.softmax(attention_scores, dim=1)
  19. context_vector = (encoder_outputs * attention_weights.unsqueeze(2)).sum(dim=1)
  20. return context_vector, attention_weights

 以上Attention类是注意力机制的神经网络模型,该模型接收两个输入参数:隐藏状态编码器输出。其中,隐藏状态是解码器中上一个时间步骤的输出,而编码器输出是编码器模型对输入序列进行编码后的输出。编码器输出和隐藏状态被用于计算上下文向量注意力权重。通过将隐藏状态和编码器输出进行拼接,然后将结果通过线性层进行处理,并使用tanh激活函数后得到能量矩阵(energy)。接着,使用另一个线性层(self.v)将能量矩阵转换成注意力得分(attention scores),并使用softmax函数转换成注意力权重(attention weights)。最后,根据注意力权重对编码器输出进行加权组合得到上下文向量。

整个过程可以简单概括为:先将隐藏状态和编码器输出连接起来,然后使用线性转换和tanh激活函数计算能量矩阵,再使用线性转换和softmax函数计算注意力权重,最后使用注意力权重对编码器输出进行加权组合得到上下文向量。

二、GRU模型构建+注意力机制

  1. class GRUModel(nn.Module):
  2. def __init__(self, input_size, hidden_size, output_size, num_layers, dropout=0.5):
  3. super(GRUModel, self).__init__()
  4. self.hidden_size = hidden_size
  5. self.num_layers = num_layers
  6. self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
  7. self.attention = Attention(hidden_size)
  8. self.fc = nn.Linear(hidden_size, output_size)
  9. self.dropout = nn.Dropout(dropout)
  10. def forward(self, x):
  11. h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size)
  12. out, hidden = self.gru(x, h0)
  13. out, attention_weights = self.attention(hidden[-1], out)
  14. out = self.dropout(out)
  15. out = self.fc(out)
  16. return out

GRUModel类的初始化方法中,先调用了父类构造函数初始化。然后定义了一个GRU层,并将其输出传入Attention类中计算上下文向量和注意力权重。最后将上下文向量送入一个线性层,并加上dropout操作,为了防止过拟合现象,然后输出模型的预测结果。通过这个模型的设计,我们可以将输入序列和输出序列的长度变化对模型的性能影响降到最小,并且利用注意力机制使模型能够更好的关注序列中的重要信息。

三、数据生成与加载

  1. # 3. 准备数据集
  2. class SampleDataset(Dataset):
  3. def __init__(self):
  4. self.sequences = []
  5. self.labels = []
  6. for _ in range(1000):
  7. seq = torch.randn(10, 5)
  8. label = torch.zeros(2)
  9. if seq.sum() > 0:
  10. label[0] = 1
  11. else:
  12. label[1] = 1
  13. self.sequences.append(seq)
  14. self.labels.append(label)
  15. def __len__(self):
  16. return len(self.sequences)
  17. def __getitem__(self, idx):
  18. return self.sequences[idx], self.labels[idx]
  19. train_set_split = int(0.8 * len(SampleDataset()))
  20. train_set, test_set = torch.utils.data.random_split(SampleDataset(),
  21. [train_set_split, len(SampleDataset()) - train_set_split])
  22. train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
  23. test_loader = DataLoader(test_set, batch_size=32, shuffle=False)

四、模型训练

  1. # 4. 定义训练过程
  2. def train(model, loader, criterion, optimizer, device):
  3. model.train()
  4. running_loss = 0.0
  5. correct = 0
  6. total = 0
  7. for batch_idx, (inputs, labels) in enumerate(loader):
  8. inputs, labels = inputs.to(device), labels.to(device)
  9. optimizer.zero_grad()
  10. outputs = model(inputs)
  11. loss = criterion(outputs, labels)
  12. loss.backward()
  13. optimizer.step()
  14. running_loss += loss.item()
  15. _, predicted = torch.max(outputs, 1)
  16. _, true_labels = torch.max(labels, 1)
  17. total += true_labels.size(0)
  18. correct += (predicted == true_labels).sum().item()
  19. print("Train Loss: {:.4f}, Acc: {:.2f}%".format(running_loss / (batch_idx + 1), 100 * correct / total))
  20. # 5. 定义评估过程
  21. def evaluate(model, loader, criterion, device):
  22. model.eval()
  23. running_loss = 0.0
  24. correct = 0
  25. total = 0
  26. with torch.no_grad():
  27. for batch_idx, (inputs, labels) in enumerate(loader):
  28. inputs, labels = inputs.to(device), labels.to(device)
  29. outputs = model(inputs)
  30. loss = criterion(outputs, labels)
  31. running_loss += loss.item()
  32. _, predicted = torch.max(outputs, 1)
  33. _, true_labels = torch.max(labels, 1)
  34. total += true_labels.size(0)
  35. correct += (predicted == true_labels).sum().item()
  36. print("Test Loss: {:.4f}, Acc: {:.2f}%".format(running_loss / (batch_idx + 1), 100 * correct / total))
  37. # 6. 训练模型并评估
  38. device = "cuda" if torch.cuda.is_available() else "cpu"
  39. model = GRUModel(input_size=5, hidden_size=10, output_size=2, num_layers=1).to(device)
  40. criterion = nn.BCEWithLogitsLoss()
  41. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  42. num_epochs = 100
  43. for epoch in range(num_epochs):
  44. print("Epoch {}/{}".format(epoch + 1, num_epochs))
  45. train(model, train_loader, criterion, optimizer, device)
  46. evaluate(model, test_loader, criterion, device)

运行结果:

  1. Epoch 97/100
  2. Train Loss: 0.0264, Acc: 99.75%
  3. Test Loss: 0.1267, Acc: 94.50%
  4. Epoch 98/100
  5. Train Loss: 0.0294, Acc: 99.75%
  6. Test Loss: 0.1314, Acc: 95.00%
  7. Epoch 99/100
  8. Train Loss: 0.0286, Acc: 99.75%
  9. Test Loss: 0.1280, Acc: 94.50%
  10. Epoch 100/100
  11. Train Loss: 0.0286, Acc: 99.75%
  12. Test Loss: 0.1324, Acc: 95.50%

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Guff_9hys/article/detail/824799
推荐阅读
相关标签
  

闽ICP备14008679号