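The code below implements additive attention. For a query q and a key k, the score is produced by a small feed-forward network,

$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}),$$

and a softmax over the key dimension turns the scores into attention weights, which are then used to take a weighted sum of the values. This is exactly what W_q, W_k and W_v compute in the class below.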
import math
import torch
from torch import nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    def __init__(self, keys_size, queries_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_q = nn.Linear(queries_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(keys_size, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        queries, keys = self.W_q(queries), self.W_k(keys)
        '''
        queries --> [batch_size, queries_length, num_hiddens]
        keys    --> [batch_size, keys_length, num_hiddens]
        '''
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        '''
        queries.unsqueeze(2) --> [batch_size, queries_length, 1, num_hiddens]
        keys.unsqueeze(1)    --> [batch_size, 1, keys_length, num_hiddens]
        features             --> [batch_size, queries_length, keys_length, num_hiddens]
        '''
        features = torch.tanh(features)
        scores = self.W_v(features).squeeze(-1)
        '''
        self.W_v(features) --> [batch_size, queries_length, keys_length, 1]
        scores             --> [batch_size, queries_length, keys_length]
        '''
        # softmax over the key dimension (the last dimension), not dim=1
        self.attention_weights = F.softmax(scores, dim=-1)
        '''self.attention_weights --> [batch_size, queries_length, keys_length]'''
        return torch.bmm(self.dropout(self.attention_weights), values)
        '''output --> [batch_size, queries_length, values_features_num]'''


###############
### Example ###
###############
queries, keys = torch.normal(0, 1, (2, 2, 20)), torch.ones((2, 10, 2))
# the two value matrices in the `values` mini-batch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
attention = AdditiveAttention(
    keys_size=2, queries_size=20, num_hiddens=8, dropout=0.1)
attention.eval()
output = attention(queries, keys, values)
'''
Because all ten keys are identical, the attention weights are uniform (0.1 each)
and the output is the mean of the value rows:
output:
tensor([[[18., 19., 20., 21.],
         [18., 19., 20., 21.]],

        [[18., 19., 20., 21.],
         [18., 19., 20., 21.]]])
shape: [2, 2, 4]
'''
# Broadcasting example that mirrors queries.unsqueeze(2) + keys.unsqueeze(1) above
aa = torch.arange(12).reshape(1, 1, 4, 3)
'''
output:
tensor([[[[ 0,  1,  2],
          [ 3,  4,  5],
          [ 6,  7,  8],
          [ 9, 10, 11]]]])
'''
bb = torch.arange(6).reshape(1, 2, 1, 3)
'''
output:
tensor([[[[0, 1, 2]],

         [[3, 4, 5]]]])
'''
aa + bb
'''
output:
tensor([[[[ 0,  2,  4],
          [ 3,  5,  7],
          [ 6,  8, 10],
          [ 9, 11, 13]],

         [[ 3,  5,  7],
          [ 6,  8, 10],
          [ 9, 11, 13],
          [12, 14, 16]]]])
'''
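The next class implements scaled dot-product attention, which replaces the feed-forward scorer with an inner product scaled by $\sqrt{d}$:

$$\mathrm{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V},$$

where d is the feature dimension shared by the queries and the keys (queries.shape[-1] in the code).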
class DotProductAttention(nn.Module):
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        '''
        queries --> [batch_size, queries_length, queries_features_num]
        keys    --> [batch_size, keys_values_length, keys_features_num]
        values  --> [batch_size, keys_values_length, values_features_num]
        In dot-product attention: queries_features_num = keys_features_num
        '''
        d = queries.shape[-1]
        # swapping the last two dimensions of keys corresponds to the transpose in the formula
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        # softmax over the key dimension (the last dimension)
        self.attention_weights = F.softmax(scores, dim=-1)
        return torch.bmm(self.dropout(self.attention_weights), values)


queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
dot_output = attention(queries, keys, values)
print(dot_output)
'''
Again all keys are identical, so the weights are uniform and the output is the
mean of the value rows:
dot_output:
tensor([[[18., 19., 20., 21.]],

        [[18., 19., 20., 21.]]])
'''
The attention mechanism is often used together with sequence2sequence models. In that setting the query is the decoder's hidden state from the previous time step, while the keys and the values are both the encoder outputs at every time step (this is exactly how the decoder code below uses them).
The basic flow of sequence2sequence with attention is as follows: the encoder runs over the source sequence and returns its per-step outputs together with its final hidden state; at each decoding step, the top layer of the current hidden state serves as the query to attend over the encoder outputs; the resulting context vector is concatenated with the current decoder input and fed into the decoder GRU to update the hidden state; finally, the GRU output is passed through a linear layer to produce the prediction.
sequence2sequence consists of two parts, an encoder and a decoder, with the attention mechanism added inside the decoder. The encoder is defined first; the code is as follows:
class Encoder(nn.Module):
    def __init__(self, inputs_dim, num_hiddens, hiddens_layers):
        super(Encoder, self).__init__()
        self.rnn1 = nn.GRU(
            input_size=inputs_dim,
            hidden_size=num_hiddens,
            num_layers=hiddens_layers)

    def forward(self, inputs):
        '''
        Because nn.GRU is created without batch_first=True, the expected input
        layout is [time_step_num, batch_size, num_features].
        Output shapes:
        encOut: [time_step_num, batch_size, hiddens_num]
        hidSta: [num_layers, batch_size, hiddens_num]
        '''
        inputs = inputs.permute(1, 0, 2)
        encOut, hidSta = self.rnn1(inputs)
        return encOut, hidSta


class AttentionDecoder(nn.Module):
    def __init__(self, inputs_dim, num_hiddens, num_layers, outputs_dim, dropout):
        super(AttentionDecoder, self).__init__()
        self.attention = AdditiveAttention(
            num_hiddens, num_hiddens, num_hiddens, dropout)
        self.rnn = nn.GRU(
            inputs_dim + num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.dense = nn.Linear(num_hiddens, outputs_dim)

    def forward(self, inputs, states):
        '''
        inputs: [batch_size, time_step_num, features]
        states: enc_output, enc_hidden_state
        '''
        enc_outputs, hidden_state = states
        '''reshape enc_outputs to [batch_size, time_step_num, enc_hidden_num]'''
        enc_outputs = enc_outputs.permute(1, 0, 2)
        '''reshape inputs to [time_step_num, batch_size, features_num]'''
        inputs = inputs.permute(1, 0, 2)
        outputs, self._attention_weights = [], []
        '''process the inputs one time step at a time and fuse each step with the context information'''
        for x in inputs:
            '''use the last layer of the encoder hidden state as the query and insert a dimension at position 1:
            hidden_state[-1]: [batch_size, enc_hidden_num] --> [batch_size, 1, enc_hidden_num]'''
            query = hidden_state[-1].unsqueeze(dim=1)
            '''context: [batch_size, query_length=1, hiddens_num]'''
            context = self.attention(query, enc_outputs, enc_outputs)
            x = torch.cat((context, x.unsqueeze(dim=1)), dim=-1)
            '''update hidden_state'''
            out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state)
            outputs.append(out)
            self._attention_weights.append(self.attention.attention_weights)
        outputs = self.dense(torch.cat(outputs, dim=0))
        return outputs.permute(1, 0, 2), [enc_outputs, hidden_state]


###############
### Example ###
###############
encoder = Encoder(inputs_dim=10, num_hiddens=20, hiddens_layers=2)
decoder = AttentionDecoder(
    inputs_dim=10, num_hiddens=20, num_layers=2, outputs_dim=8, dropout=0.1)
inputs = torch.normal(0, 1, (4, 8, 10))
state = encoder(inputs)
dec_inputs = torch.normal(0, 1, (4, 1, 10))
dec_output, state = decoder(dec_inputs, state)
print(dec_output.shape)
'''
output: torch.Size([4, 1, 8])
'''
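Since the decoder stores the attention weights of every decoding step in self._attention_weights, they can be inspected directly after the forward pass above. A minimal check, using the tensors from the example as defined there:

# Each entry corresponds to one decoding step and has shape
# [batch_size, query_length=1, source_time_steps].
# With dec_inputs of shape (4, 1, 10) there is exactly one step.
weights = decoder._attention_weights[0]
print(weights.shape)        # torch.Size([4, 1, 8])
print(weights.sum(dim=-1))  # each row sums to 1 after the softmax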
class EncoderDecoder(nn.Module):
    """The base class for the encoder-decoder architecture."""
    def __init__(self, encoder, decoder, **kwargs):
        super(EncoderDecoder, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        state = self.encoder(enc_X, *args)
        dec_state = self.decoder(dec_X, state)
        return dec_state


net = EncoderDecoder(encoder, decoder)
output = net(inputs, dec_inputs)
print(output[0].shape)  # --> [4, 1, 8]
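For completeness, here is a minimal training sketch built on the objects defined above. It is illustrative only: the MSE loss, the Adam optimizer, the learning rate and the random regression-style targets are assumptions, not part of the original example.

# Minimal training sketch (assumptions: MSE loss, Adam, lr=1e-3,
# random targets matching the output shape [4, 1, 8] from the example above).
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()
targets = torch.normal(0, 1, (4, 1, 8))

net.train()
for step in range(5):
    optimizer.zero_grad()
    dec_out, _ = net(inputs, dec_inputs)   # dec_out: [4, 1, 8]
    loss = loss_fn(dec_out, targets)
    loss.backward()
    optimizer.step()
    print(step, loss.item())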