当前位置:   article > 正文

[pytorch] 定义自己的dataloader

[pytorch] 定义自己的dataloader

在使用自己数据集训练网络时,往往需要定义自己的dataloader。

1 定义datalaoder

一般将dataloader封装为一个类,这个类继承自 torch.utils.data.dataset

from torch.utils.data import dataset

class LoadData(Dataset):  # 注意父类的名称,不能写dataset
    pass
  • 1
  • 2
  • 3
  • 4

需要注意的是dataset是模块名,而Dataset是类名,在python中模块名和类名是完全独立的命名空间,因此这里的父类需要写成 dataset.Dataset。

在我们定义的LoadData中,至少需要有三个方法:

  • __init__方法,主要用来定义数据的预处理
  • __getitem__方法,返回数据的item和label
  • __len__方法,返回数据个数

整体大致架构:

import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class LoadData(dDataset):

    def __init__(self):
        pass

    def __getitem__(self,index):
        pass

    def __len__(self):
        pass

dataset = Loaddata()
train_loader = DataLoader(dataset = dataset,batch_size = 32,shuffle = Ture,num_workers=2)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

1.1 init

__init__方法需要传入至少两个参数:

  • 一般数据的地址和标签已经被保存在某个文档中了(这里是txt格式的文档)。因此需要传入这个文档的地址。
  • 因为__init__方法要做预处理,一般用来train的预处理和test的预处理是不同的,因此需要区分二者的参数。
def __init__(self, txt_path, train=True):
        super(LoadData, self).__init__()
        self.img_info = self.get_img(txt_path)
        self.train = train

        # train预处理
        self.train_transforms = transforms.Compose([
            transforms.Resize(20),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

        # test预处理
        self.test_transforms = transforms.Compose([
            transforms.Resize(20),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    # 这个函数是用来读txt文档的
    def get_img(self, txt_path):
        with open(txt_path, 'r', encoding='utf-8') as f:
            imgs_info = f.readlines()
            imgs_info = list(map(lambda x:x.strip().split('\t'), imgs_info))
            return imgs_info
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

1.2 getitem

__getitem__方法只需要根据index返回数据的item和label。

def __getitem__(self, index):
        img_path, label = self.img_info[index]
        img = Image.open(img_path)
        label = int(label)

        # 注意区分预处理
        if self.train:
            img = self.train_transforms(img)
        else:
            img = self.test_transforms(img)

        return img, label
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

1.3 len

__len__方法最简单,仅返回数据项个数。

def __len__(self):
        return len(self.img_info)
  • 1
  • 2

2 调用dataloader

以训练数据为例,调用dataloader需要两步:

  • 将自定义的LoadData实例化
  • 传入torch.utils.data.dataloader中
from torch.utils.data import Dataloader

train_dataset = LoadData(txt_path='XXXX', train=True)

train_loader = dataloader.Dataloader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True
    )
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

至此,一个最简单的dataloader就完成了!
可以用以下代码测试:

for image, label in train_loader:
    print(image.shape)
    print(label)
  • 1
  • 2
  • 3

参考

https://zhuanlan.zhihu.com/p/399447239

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/47404
推荐阅读
相关标签
  

闽ICP备14008679号