
AIGC Notes -- Building a Conditional Autoregressive Transformer


1--Overview

        1. An autoregressive Transformer requires that each token attend only to itself and the tokens before it, so an attention mask enforcing this constraint must be generated (the code provides two ways to define such an autoregressive attention mask; a short standalone sketch after this overview also illustrates them);

        2. Cross Attention is used to fuse the condition modality with the input modality: the input modality serves as the Query, and the condition modality serves as the Key and Value.
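
        A minimal standalone sketch of the two equivalent mask definitions (the small seq_length of 4 and the hand-rolled softmax are used here only for easy inspection; they are not part of the code in section 2):

import torch

seq_length = 4  # small length so the mask pattern is easy to inspect

# Way 1: boolean mask, True marks positions that must NOT be attended to
bool_mask = torch.triu(torch.ones(seq_length, seq_length), diagonal=1).bool()

# Way 2: additive float mask, -inf at forbidden positions, 0 elsewhere
float_mask = torch.zeros(seq_length, seq_length).masked_fill(bool_mask, float("-inf"))

# applying either mask to the same attention scores yields the same weights
scores = torch.randn(seq_length, seq_length)
w_bool = torch.softmax(scores.masked_fill(bool_mask, float("-inf")), dim=-1)
w_float = torch.softmax(scores + float_mask, dim=-1)
print(w_bool)  # lower-triangular attention weights
assert torch.allclose(w_bool, w_float)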

2--Code

import torch
import torch.nn as nn


class CrossAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, input_x: torch.Tensor, condition: torch.Tensor, attn_mask: torch.Tensor = None):
        '''
        query: input_x
        key:   condition
        value: condition
        '''
        input_x = self.cross_attn(input_x, condition, condition, attn_mask=attn_mask)[0]
        return input_x


class Cond_Autoregressive_layer(nn.Module):
    def __init__(self, input_dim: int, condition_dim: int, embed_dim: int, num_heads: int):
        super(Cond_Autoregressive_layer, self).__init__()
        self.linear1 = nn.Linear(input_dim, embed_dim)      # project the input modality to embed_dim
        self.linear2 = nn.Linear(condition_dim, embed_dim)  # project the condition modality to embed_dim
        self.cond_multihead_attn = CrossAttention(embed_dim=embed_dim, num_heads=num_heads)

    def forward(self, input_x: torch.Tensor, condition: torch.Tensor, attention_mask1: torch.Tensor, attention_mask2: torch.Tensor):
        # query comes from the input, key and value both come from the condition
        y1 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(condition), attn_mask=attention_mask1)
        y2 = self.cond_multihead_attn(self.linear1(input_x), self.linear2(condition), attn_mask=attention_mask2)
        return y1, y2


if __name__ == "__main__":
    # set sequence length, feature dims, embedding dim, number of attention heads
    seq_length = 10
    input_dim = 32
    condition_dim = 128
    embed_dim = 64
    num_heads = 8

    # init input sequence and condition, shape (seq_length, batch_size=1, dim)
    input_x = torch.randn(seq_length, 1, input_dim)
    condition = torch.randn(seq_length, 1, condition_dim)

    # create two attention masks (they have the same effect)
    attention_mask1 = torch.triu((torch.ones((seq_length, seq_length)) == 1), diagonal=1)  # bool mask: True = not attended
    attention_mask2 = attention_mask1.float()  # True -> 1.0, False -> 0.0
    attention_mask2 = attention_mask2.masked_fill(attention_mask2 == 1, float("-inf"))  # additive mask: 1.0 -> -inf

    # init model
    AG_layer = Cond_Autoregressive_layer(input_dim, condition_dim, embed_dim, num_heads)

    # forward
    y1, y2 = AG_layer(input_x, condition, attention_mask1, attention_mask2)

    # demonstrate that attention_mask1 and attention_mask2 have the same effect
    assert(y1[0].equal(y2[0]))
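
A hypothetical check (not part of the original code) that can be appended to the __main__ block above to confirm the autoregressive behaviour: perturbing a later condition token should leave earlier output positions unchanged, since position i may only attend to condition[:i+1]. The variable names condition_perturbed and y1_p are illustrative only.

    # perturb only the last condition token
    condition_perturbed = condition.clone()
    condition_perturbed[-1] += torch.randn_like(condition_perturbed[-1])
    y1_p, _ = AG_layer(input_x, condition_perturbed, attention_mask1, attention_mask2)
    # all positions except the last attend only to unchanged condition tokens
    assert torch.allclose(y1[:-1], y1_p[:-1])
    # the last position does attend to the perturbed token, so it will generally differ
    print((y1[-1] - y1_p[-1]).abs().max())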
