赞
踩
在PyTorch中view()
, transpose()
和 permute()
函数都是用于改变张量(Tensor)维度结构的,但它们的作用和使用场景有所不同。
import torch batch = 3 seq_size = 2 embed = 8 torch.random.manual_seed(123) x = torch.randint(25, (batch, seq_size, embed)).float() print(x) # tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.], # [24., 4., 11., 21., 16., 21., 12., 24.]], # # [[14., 1., 13., 5., 0., 16., 5., 22.], # [ 9., 2., 21., 6., 15., 1., 16., 15.]], # # [[23., 4., 4., 16., 1., 18., 0., 20.], # [ 9., 1., 1., 7., 13., 21., 12., 12.]]]) # 将后两维度张量展平,将每一词的词嵌入按行连接 z = x.view(batch, -1) print(z) # tensor([[7., 14., 2., 10., 5., 17., 11., 7., 24., 4., 11., 21., 16., 21., 12., 24.], # [14., 1., 13., 5., 0., 16., 5., 22., 9., 2., 21., 6., 15., 1., 16., 15.], # [23., 4., 4., 16., 1., 18., 0., 20., 9., 1., 1., 7., 13., 21., 12., 12.]]) # transformer中多头注意力机制常用,把最后一维词嵌入的维度进行两次切割 # 切割出来多余的那部分做为batch放在第一个维度上 y = x.view(batch * 2, -1, embed // 2) print(y) # tensor([[[ 7., 14., 2., 10.], # [ 5., 17., 11., 7.]], # # [[24., 4., 11., 21.], # [16., 21., 12., 24.]], # # [[14., 1., 13., 5.], # [ 0., 16., 5., 22.]], # # [[ 9., 2., 21., 6.], # [15., 1., 16., 15.]], # # [[23., 4., 4., 16.], # [ 1., 18., 0., 20.]], # # [[ 9., 1., 1., 7.], # [13., 21., 12., 12.]]])
import torch batch = 3 seq_size = 2 embed = 8 torch.random.manual_seed(123) x = torch.randint(25, (batch, seq_size, embed)).float() print(x) # tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.], # [24., 4., 11., 21., 16., 21., 12., 24.]], # # [[14., 1., 13., 5., 0., 16., 5., 22.], # [ 9., 2., 21., 6., 15., 1., 16., 15.]], # # [[23., 4., 4., 16., 1., 18., 0., 20.], # [ 9., 1., 1., 7., 13., 21., 12., 12.]]]) # 将后两维交换(转置),将每一词的词嵌入按列展示 z = x.transpose(1, 2) # 等价 x.transpose(2, 1) print(z) # tensor([[[ 7., 24.], # [14., 4.], # [ 2., 11.], # [10., 21.], # [ 5., 16.], # [17., 21.], # [11., 12.], # [ 7., 24.]], # # [[14., 9.], # [ 1., 2.], # [13., 21.], # [ 5., 6.], # [ 0., 15.], # [16., 1.], # [ 5., 16.], # [22., 15.]], # # [[23., 9.], # [ 4., 1.], # [ 4., 1.], # [16., 7.], # [ 1., 13.], # [18., 21.], # [ 0., 12.], # [20., 12.]]])
import torch batch = 3 seq_size = 2 embed = 8 torch.random.manual_seed(123) x = torch.randint(25, (batch, seq_size, embed)).float() print(x) # tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.], # [24., 4., 11., 21., 16., 21., 12., 24.]], # # [[14., 1., 13., 5., 0., 16., 5., 22.], # [ 9., 2., 21., 6., 15., 1., 16., 15.]], # # [[23., 4., 4., 16., 1., 18., 0., 20.], # [ 9., 1., 1., 7., 13., 21., 12., 12.]]]) # 将后两维重新排序 # 注意这样是报错x.permute(2, 1)或者permute(1, 2, 1)都是非法的 z = x.permute(0, 2, 1) # 等价 x.transpose(2, 1), # print(z) # 如果我们想要三个维度都交换transpose是做不到的 # 至于有什么实际意义就不讨论了 y = x.permute(2, 1, 0) print(y) # tensor([[[ 7., 14., 23.], # [24., 9., 9.]], # # [[14., 1., 4.], # [ 4., 2., 1.]], # # [[ 2., 13., 4.], # [11., 21., 1.]], # # [[10., 5., 16.], # [21., 6., 7.]], # # [[ 5., 0., 1.], # [16., 15., 13.]], # # [[17., 16., 18.], # [21., 1., 21.]], # # [[11., 5., 0.], # [12., 16., 12.]], # # [[ 7., 22., 20.], # [24., 15., 12.]]])
import torch
seq_size = 2
embed = 8
torch.random.manual_seed(123)
x = torch.randint(25, (seq_size, embed)).float()
print(x)
# tensor([[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]])
z = x.unsqueeze(0) # 等价 torch.unsqueeze(x, dim=0)
print(z)
# tensor([[[ 7., 14., 2., 10., 5., 17., 11., 7.],
# [24., 4., 11., 21., 16., 21., 12., 24.]]])
总结:
view()
更侧重于保持数据不变的前提下改变张量的维度形状,常用于展平、重塑等操作。
transpose()
是特定的维度交换操作,只涉及两个维度的变换。
permute()
则提供了更灵活的维度重排功能,可以处理多维度情况下的整体维度顺序调整。
unsqueeze()
指定位置增加张量维度。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。