赞
踩
转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]
如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~
目录
以下内容若有错误,欢迎指出!
我们现在知道以下几个知识点:
好了,现在有一个问题,如果结合索引与copy_操作,那是否会复制成功?
答案是,不会成功。我们可以用代码测试一下:
- import torch
-
- print('>> 使用=号直接复制 <<')
- buff = torch.arange(5)
- mask = [True, False, False, True, False]
- print('输出原始变量:', buff)
- print('输出索引掩码:', mask)
- print('输出索引变量:', buff[mask])
- buff[mask][0] = 10
- print('索引变量修改:', buff)
- buff[mask] = torch.tensor([8, 9])
- print('索引变量赋值:', buff)
-
- print('*' * 50)
-
- buff = torch.arange(5)
- print('输出原始变量:', buff)
- print('输出切片索引:', '1:3')
- buff_indices = buff[1:3]
- print('输出切片变量:', buff[buff_indices])
- buff[1:3][0] = 10
- print('切片变量修改:', buff)
- buff[1:3] = torch.tensor([8, 9])
- print('切片变量赋值:', buff)
-
- print('=' * 50)
-
- print('>> 使用copy_原地复制 <<')
-
- buff = torch.arange(5)
- mask = [True, False, False, True, False]
- print('输出原始变量:', buff)
- print('输出索引掩码:', mask)
- print('输出索引变量:', buff[mask])
- buff[mask].copy_(torch.tensor([8, 9]))
- print('索引变量copy:', buff)
-
- print('*' * 50)
-
- buff = torch.arange(5)
- print('输出原始变量:', buff)
- print('输出切片索引:', '1:3')
- print('输出切片变量:', buff[1:3])
- buff[1:3].copy_(torch.tensor([8, 9]))
- print('切片变量copy:', buff)

输出结果(改变的地方加粗了):
>> 使用=号直接复制 <<
输出原始变量: tensor([0, 1, 2, 3, 4])
输出索引掩码: [True, False, False, True, False]
输出索引变量: tensor([0, 3])
索引变量修改: tensor([0, 1, 2, 3, 4])
索引变量赋值: tensor([8, 1, 2, 9, 4])
**************************************************
输出原始变量: tensor([0, 1, 2, 3, 4])
输出切片索引: 1:3
输出切片变量: tensor([1, 2])
切片变量修改: tensor([ 0, 10, 2, 3, 4])
切片变量赋值: tensor([0, 8, 9, 3, 4])
==================================================
>> 使用copy_原地复制 <<
输出原始变量: tensor([0, 1, 2, 3, 4])
输出索引掩码: [True, False, False, True, False]
输出索引变量: tensor([0, 3])
索引变量赋值: tensor([0, 1, 2, 3, 4])
**************************************************
输出原始变量: tensor([0, 1, 2, 3, 4])
输出切片索引: 1:3
输出切片变量: tensor([1, 2])
切片变量赋值: tensor([0, 8, 9, 3, 4])
在PyTorch中,当你使用布尔掩码或索引来访问张量时,通常会创建一个新的张量,而不是对原始张量进行原地修改。在PyTorch中,切片操作通常会返回一个视图,而不是数据的副本。这意味着切片操作返回的张量和原始张量共享相同的内存。因此,对切片后的张量进行的任何修改都会影响到原始张量。与此相对,布尔掩码索引返回的是数据的副本,因此修改索引得到的张量不会影响原始张量。
因此可见,由于索引返回的是新张量,而copy_是原地复制,因此对于原来的变量来说并没有影响,所以不会复制成功。
而=号这个赋值操作,不管是基本索引还是高级索引,由于底层都是对张量的原地操作,因此确实可以赋值成功。
根据以上内容就知道,有时候我们如果这样用,那就是错的:
- buff = torch.arange(5)
- mask = [True, False, False, True, False]
- buff[mask].copy_(torch.tensor([8, 9]))
如果确实想结合索引和copy_一起用怎么办?那么可以试试masked_scatter_。
:
来指定范围。对于背景知识里的第4点,我们也来通过代码验证一下。
基本索引包括标量索引、切片操作和整数索引。PyTorch通常会返回原始张量的视图,这意味着它们共享相同的底层数据。因此,对视图的修改会影响原始张量。例如:
- import torch
-
- a = torch.tensor([1, 2, 3, 4])
- b = a[:2] # 基本索引,b 是 a 的视图
- b[0] = 10 # 修改视图会影响原始张量
- print(a) # 输出: tensor([10, 2, 3, 4])
高级索引包括使用布尔数组、整数数组或多维索引。PyTorch和NumPy一样,高级索引会返回一个新的张量,即副本,不与原始数据共享内存。因此,对副本的修改不会影响原始张量。例如:
- import torch
-
- a = torch.tensor([1, 2, 3, 4])
- indices = torch.tensor([0, 2])
- b = a[indices] # 高级索引,b 是 a 的副本
- b[0] = 10 # 修改副本不会影响原始张量
- print(a) # 输出: tensor([1, 2, 3, 4])
- print(b) # 输出: tensor([10, 3])
无论是通过基本索引还是高级索引,赋值操作都是原地操作,这意味着它们会直接修改原始张量的内容。例如:
基本索引赋值:
- a = torch.tensor([1, 2, 3, 4])
- a[:2] = torch.tensor([10, 20]) # 原地修改 a
- print(a) # 输出: tensor([10, 20, 3, 4])
高级索引赋值:
- a = torch.tensor([1, 2, 3, 4])
- indices = torch.tensor([0, 2])
- a[indices] = torch.tensor([10, 20]) # 原地修改 a
- print(a) # 输出: tensor([10, 2, 20, 4])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。