赞
踩
熟悉深度学习的小伙伴一定都知道:深度学习模型训练主要由数据、模型、损失函数、优化器以及迭代训练五个模块组成。如下图所示,Pytorch数据读取机制则是数据模块中的主要分支。
Pytorch数据读取是通过Dataset+Dataloader的方式完成。其中,
Dataset用于解决数据从哪里读取以及如何读取的问题。 Pytorch给定的Dataset是一个抽象类,所有自定义的数据集都要继承Dataset,并重写**init()、getitem()和__len__()**类方法,以供DataLoader类直接调用。
下面是笔者以cifar10数据集为例实现Dataset自定义数据集的代码样例。
from torch.utils.data import Dataset from PIL import Image import os class Mydata(Dataset): """ 步骤一:继承 torch.utils.data.Dataset 类 """ def __init__(self,data_dir,label_dir): """ 步骤二:实现 __init__ 函数,初始化数据集,将样本和标签映射到列表中 """ self.data_dir = data_dir self.label_dir = label_dir # 用join把路径拼接一起可以避免一些因“/”引发的错误 self.path = os.path.join(self.data_dir,self.label_dir) # 将该路径下的所有文件变成一个列表 self.img_path = os.listdir(self.path) def __getitem__(self,idx) """ 步骤三:实现 __getitem__ 函数,定义指定 index 时如何获取数据,并返回单条数据(样本数据、对应的标签) """ # 根据index(idx),从列表中取出图片 # img_path列表里每个元素就是对应图片文件名 img_name = self.img_path[idx] # 获得对应图片路径 img_item_path = os.path.join(self.data_dir,self.label_dir,img_name) # 使用PIL库下Image工具,打开对应路径图片 img = Image.open(img_item_path) label = self.label_dir # 返回图片和对应标签 return img,label def __len__(self): """ 步骤四:实现 __len__ 函数,返回数据集的样本总数 """ return len(self.img_path) # data_dir,label_dir可自定义数据集目录 train_custom_dataset = MyData(data_dir,label_dir) test_custom_dataset = MyData(data_dir,label_dir)
在实际项目中,当数据量很大,考虑到内存有限、I/O速度等问题,训练中不可能一次性将所有数据加载到内存或者只用一个进行加载数据,此时就需要的是多进程、迭代加载,Dataloader便应运而生。
DataLoader是一个可迭代的数据装载器,组合了数据集和采样器,并在给定数据集上提供可迭代对象。可以完成对数据集中多个对象的集成。
Pytorch的数据读取机制中DataLoader模块包括Sampler和Dataset两个子模块,其中Sampler模块生成索引index;Dataset模块是根据索引读取数据。DataLoader读取数据流程如下图所示。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UZqgYimv-1684309723395)(imgs/230424183501.png)]
Pytorch中DataLoader类定义如下:
class torch.utils.data.DataLoader(
"""
构建可迭代的数据装载器,训练时,每一个for循环,每一次迭代,
从DataLoader中获取一个batch_size大小的数据
"""
dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_workers=0,
collate_fn=None,
pin_memory=False,
drop_last=False,
)
补充说明
Epoch:所有训练样本都已输入到模型中,称为一个epoch
Iteration:一批样本(batch_size)输入到模型中,称为一个Iteration。
Batchsize:一批样本的大小,称为Batchsize。用于决定一个epoch有多少个Iteration。
代码实现示例如下。
import torch import torch.utils.data as Data BATCH_SIZE = 5 x = torch.linspace(1, 10, 10) y = torch.linspace(10, 1, 10) # 将数据集转换为torch可识别的类型 torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0 ) for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): print('epoch', epoch, '| step:', step, '| batch_x', batch_x.numpy(), '| batch_y:', batch_y.numpy())
通过上述方法即可初始化一个数据读取器loader,用于加载训练数据集torch_dataset。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。