当前位置:   article > 正文

Pytorch:torch.stack 和 torch.as_tensor

Pytorch:torch.stack 和 torch.as_tensor

torch.stacktorch.as_tensorPyTorch 中的两个函数,它们用于处理 tensor 的创建和操作,但它们各自的用途和功能是不同的。

torch.stack

torch.stack:这个函数用于将一系列的 tensors 沿着一个新的维度合并所有 tensors 必须有相同的形状

以下是 torch.stack 的使用例子:

i

mport torch

# 假设我们有三个相同形状的 tensors
t1 = torch.tensor([1, 2])
t2 = torch.tensor([3, 4])
t3 = torch.tensor([5, 6])

# 我们可以使用 torch.stack 将它们堆叠起来
# 这里 dim 参数表示新增维度的索引
stacked = torch.stack((t1, t2, t3), dim=0)

print(stacked)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

输出结果将是一个新的3D tensor

tensor([[1, 2],
        [3, 4],
        [5, 6]])
  • 1
  • 2
  • 3

在这个例子中,在堆叠后的tensor中增加了一个新的维度,且这个新的维度是放在原有维度的前面(dim=0)。如果你将 dim 设置为1,那么新的维度将会是在原有维度的后面。

torch.as_tensor

torch.as_tensor:这个函数将一个数据转换为 tensor。如果数据已经是一个 tensor 并且默认的数据类型和设备都与输入数据相同,则不会进行复制。数据可以是一个列表,元组,NumPy ndarray,标量以及其他类型的 tensor。

以下是 torch.as_tensor 的使用例子:

import numpy as np
import torch

# 假设我们有一个 numpy ndarray
array = np.array([7, 8, 9])

# 我们可以使用 torch.as_tensor 将其转为 tensor
tensor = torch.as_tensor(array)

print(tensor)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

输出结果:

tensor([7, 8, 9], dtype=torch.int32)
  • 1

在这个例子中,as_tensor没有复制数据,而是使用了原始的内存空间。因此,如果更改array的值,tensor里的值也会相应更改,反之亦然。如果需要避免此行为,应该使用torch.tensor,它总是进行数据的复制。

简而言之,torch.stack 是用于在新的维度上合并 tensors 的函数, torch.as_tensor 主要是用于将其他数据格式转换为 tensor,可能不会进行数据的复制。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/352986
推荐阅读
相关标签
  

闽ICP备14008679号