This is my first RNN program, written as hands-on practice: given the first line of a couplet (上联) as input, it outputs the second line (下联). It uses a seq2seq model, as shown in the figure below.
(Image source: https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/)
First, a word embedding maps each Chinese character to a 500-dimensional vector. The embedded sequence then passes through an encoder RNN (a bidirectional GRU) and a decoder RNN (a unidirectional GRU). At every step the decoder uses attention to compute a weighted sum of the encoder's top-layer outputs, and its first input is the start-of-sentence token SOS.
# Imports these modules rely on
import torch
import torch.nn as nn
import torch.nn.functional as F

# Bidirectional-GRU encoder; its output is the hidden states of the last (top) layer
class EncoderRNN(nn.Module):
    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
        # because our input size is a word embedding with number of features == hidden_size
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        # Use the word embedding to encode the input characters
        embedded = self.embedding(input_seq)
        # Pack the padded batch of variable-length sequences
        packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
        outputs, hidden = self.gru(packed, hidden)
        # Unpack padding
        outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the forward and backward outputs of the bidirectional GRU
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden

# Luong attention layer
class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1),
                                      encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        # Calculate the attention weights (energies) based on the given method
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # Transpose max_length and batch_size dimensions
        attn_energies = attn_energies.t()
        # Return the softmax normalized probability scores (with added dimension)
        return F.softmax(attn_energies, dim=1).unsqueeze(1)

# Decoder using Luong attention
class LuongAttnDecoderRNN(nn.Module):
    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        # Keep for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout
        # Define layers
        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        # Note: we run this one step (word) at a time
        # Embed the current input character (SOS on the first decoding step)
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        # Forward through unidirectional GRU
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Compute attention weights from the current GRU output
        attn_weights = self.attn(rnn_output, encoder_outputs)
        # Weighted sum of the encoder outputs, using the attention weights as coefficients
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        # Combine the context vector with the GRU output
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        # Project from hidden space back onto the character vocabulary
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        # Return output and final hidden state
        return output, hidden
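To show how these modules fit together at inference time, here is a minimal greedy-decoding sketch. The vocabulary size, the SOS/EOS token indices, the layer count, and the example character indices are placeholder assumptions made for illustration; they are not values taken from the actual project:

import torch
import torch.nn as nn

voc_size, hidden_size, n_layers = 6000, 500, 2      # hypothetical vocabulary size
SOS_token, EOS_token = 1, 2                          # hypothetical special-token indices

embedding = nn.Embedding(voc_size, hidden_size)      # characters -> 500-dim vectors, shared by both RNNs
encoder = EncoderRNN(hidden_size, embedding, n_layers)
decoder = LuongAttnDecoderRNN('dot', embedding, hidden_size, voc_size, n_layers)
encoder.eval(); decoder.eval()                       # disable dropout for inference

# Encode one already-indexed first line, shape (seq_len, batch=1)
input_seq = torch.LongTensor([[3], [4], [5]])        # hypothetical character indices
input_lengths = torch.tensor([input_seq.size(0)])
encoder_outputs, encoder_hidden = encoder(input_seq, input_lengths)

# The decoder GRU is unidirectional, so slice the encoder hidden state down to n_layers layers
decoder_hidden = encoder_hidden[:decoder.n_layers]

# Greedy decoding: start from SOS and feed the argmax character back in until EOS
decoder_input = torch.LongTensor([[SOS_token]])
result = []
for _ in range(input_seq.size(0)):
    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
    _, topi = decoder_output.topk(1)                 # index of the most probable next character
    if topi.item() == EOS_token:
        break
    result.append(topi.item())
    decoder_input = topi.view(1, 1)                  # feed the prediction back as the next input

Wrapping the decoding loop in torch.no_grad() would additionally avoid building the autograd graph during inference.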
Because an RNN processes its input sequentially, it does not accelerate well on a GPU, so the number of training iterations was limited; the Model folder contains the model after 29 epochs of training.
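The training loop itself is not shown above. As a rough sketch of what one teacher-forcing iteration over a single (first line, second line) pair could look like with these modules, reusing the encoder/decoder instances from the decoding sketch: the Adam optimizers, the learning rate, and the simple unmasked NLL loss here are assumptions made for this example, not the project's actual training code.

import torch
import torch.optim as optim

# Hypothetical optimizers and learning rate; the settings used for the 29-epoch run are not given
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=1e-4)

def train_step(input_seq, input_lengths, target_seq, SOS_token=1):
    # input_seq / target_seq: LongTensors of shape (seq_len, 1) holding character indices
    encoder.train(); decoder.train()                 # enable dropout during training
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    encoder_outputs, encoder_hidden = encoder(input_seq, input_lengths)
    decoder_hidden = encoder_hidden[:decoder.n_layers]
    decoder_input = torch.LongTensor([[SOS_token]])

    loss = 0.0
    for t in range(target_seq.size(0)):
        decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)
        # decoder_output already holds probabilities (softmax applied), so take the
        # negative log of the probability assigned to the ground-truth character
        loss = loss - torch.log(decoder_output[0, target_seq[t, 0]] + 1e-10)
        decoder_input = target_seq[t].view(1, 1)     # teacher forcing: feed the ground truth
    loss.backward()
    encoder_optimizer.step()
    decoder_optimizer.step()
    return loss.item() / target_seq.size(0)

One call to train_step per couplet pair, over the whole training set, would correspond to one epoch.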
The couplets below are outputs of the CharRNN (since the GRU's memory is randomly initialized at the start of each run, the outputs are also somewhat random):
First line: <s>天<\s>
Second line: <s>地<\s>
First line: <s>雨<\s>
Second line: <s>烟<\s>
First line: <s>米饭<\s>
Second line: <s>油茶<\s>
First line: <s>山花<\s>
Second line: <s>野禽<\s>
First line: <s>鸡冠花<\s>
Second line: <s>龙牙梨<\s>
First line: <s>孔夫子<\s>
Second line: <s>毛小公<\s>
First line: <s>今天打雷下雨<\s>
Second line: <s>昨日打人走人<\s>
First line: <s>狗和猫打架不分胜负<\s>
Second line: <s>狼与狗进球就是高多<\s>
The longer the input, the less coherent the output becomes, and the two lines may even come out with mismatched character counts, as in the following case:
First line: <s>人生没有彩排,每一天都是现场直播<\s>
Second line: <s>世海无多解势,众今岂来地网先争<\s>
My understanding is that better results could be obtained with enough additional training.
The full code is available on GitHub.