当前位置:   article > 正文

Pytorch:torch.utils.data.random_split()_torch data split

torch data split

random_split() 函数说明:

torch.utils.data.random_split(dataset, lengths, generator=<torch._C.Generator object>)

参数:

  • dataset(Dataset) -要拆分的数据集
  • lengths(序列) -要产生的分割长度
  • generator(torch.Generator) -用于随机排列的生成器。

注:关于torch.Generator详见笔记:Pytorch:torch.Generator()

pytorch: random_split(),函数的具体定义如下:

def random_split(dataset, lengths):
    r"""
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Arguments:
        dataset (Dataset): Dataset to be split
        lengths (sequence): lengths of splits to be produced
    """
    if sum(lengths) != len(dataset):
        raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

    indices = randperm(sum(lengths)).tolist()
    return [Subset(dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(lengths), lengths)]

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
'
运行

U-Net代码(详见:U-Net代码复现)为例:

n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))
  • 1
  • 2
  • 3

通过random_split()将数据分为训练集和验证集(随机)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/878201?site
推荐阅读
相关标签
  

闽ICP备14008679号