当前位置:   article > 正文

【PyTorch教程】P6-P7 数据加载_p6资源加载教程

p6资源加载教程

完整目录

关于 Dataset 数据加载

  • 如何理解数据的加载呢?
  • 就是说,要想获取自己电脑里的数据,读取它,那么就要遵守 pytorch 加载数据的规则。
    他的规则就是 定义一个class类,继承 Dataset (from torch.utils.data import Dataset),并且,在类中,定义三个函数,分别是:初始化 init、获得每一个数据 getitem、数据长度 len。
  • 这里面的过程,要很清楚:
    1、路径、合并路径、把文件夹中的每一个文件名称,做成一个列表(这是init要做的事情);
    2、访问init中的列表,把列表的名称逐一传递给一个变量,命名为name,再次合并路径,并且把文件名连接在路径之后,接下来,用PIL中的Image.open函数,读取(加载)上述路径的文件(命名为img)(这里肯定是图像了),返回 图像img和标签 label(这是getitem的工作);
    3、最后用len()返回列表的长度。
  • 定义好 以后,后面就可以实例化这个类,定义参数(本例其实是一个路径,一个夹名称了),名称可以和定义类中的不一样,但是位置要对应(奥,这可能是Python课程里说的位置参数?)。
  • 引用之前定义的类,把上述参数,传递进去。
  • 最后打印自定义数据列表的长度。

可运行的代码

import os

from PIL import Image
from torch.utils.data import Dataset


# dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 初始化,为这个函数用来设置在类中的全局变量
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)  # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径
        self.img_path = os.listdir(self.path)  # img_path 的返回值,就已经是一个列表了

    def __getitem__(self, idx):  # 获取数据对应的 label
        img_name = self.img_path[idx]  # img_name 在上一个函数的最后,返回就是一个列表了
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了
        img = Image.open(img_item_path)  # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作
        label = self.label_dir  # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录
        return img, label  # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能
        # label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的

    def __len__(self):
        return len(self.img_path)  # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作


if __name__ == '__main__':
    root_dir = "F:\\PhD\\01-Python_In_One\\Project\\【B_up】XiaoTuDui\\data\\train"
    # root_dir = "data/train"
    ants_label_dir = "ants_image"
    bees_label_dir = "bees_image"
    ants_dataset = MyData(root_dir, ants_label_dir)
    bees_dataset = MyData(root_dir, bees_label_dir)
    train_dataset = ants_dataset + bees_dataset
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34

完整代码 P6-7_read_data.py

# from torch.utils.tensorboard import SummaryWriter

# !usr/bin/env python3
# -*- coding:utf-8 -*-

"""
author :24nemo
 date  :2021年07月12日
"""

'''
Dataset:
能把数据进行编号
提供一种方式,获取数据,及其label,实现两个功能:
1、如何获取每一个数据,及其label
2、告诉我们总共有多少个数据

数据集的组织形式,有两种方式:
1、文件夹的名字,就是数据的label
2、文件名和label,分别处在两个文件夹中,label可以用txt的格式进行存储

在jupyter中,可以查看,help,两个方式:
1、help(Dataset)
2、Dataset??

Dataloader:
为网络提供不同的数据形式,比如将0、1、2、3进行打包

这一节内容很重要
'''
'''
# writer = SummaryWriter("logs")

class MyData(Dataset):
    def __init__(self, root_dir, image_dir, label_dir, transform):
        #  初始化,为这个函数用来设置在类中的全局变量
        self.root_dir = root_dir
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.label_path = os.path.join(self.root_dir, self.label_dir)
        self.image_path = os.path.join(self.root_dir, self.image_dir)
        self.image_list = os.listdir(self.image_path)
        self.label_list = os.listdir(self.label_path)
        self.transform = transform
        # 因为 label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的
        self.image_list.sort()
        self.label_list.sort()

    def __getitem__(self, idx):
        img_name = self.image_list[idx]
        label_name = self.label_list[idx]
        img_item_path = os.path.join(self.root_dir, self.image_dir, img_name)
        label_item_path = os.path.join(self.root_dir, self.label_dir, label_name)
        img = Image.open(img_item_path)

        with open(label_item_path, 'r') as f:
            label = f.readline()

        # img = np.array(img)
        img = self.transform(img)
        sample = {'img': img, 'label': label}
        return sample

    def __len__(self):
        # assert len(self.image_list) == len(self.label_list)
        return len(self.image_list)


if __name__ == '__main__':
    transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
    root_dir = "dataset/train"
    image_ants = "ants_image"
    label_ants = "ants_label"
    ants_dataset = MyData(root_dir, image_ants, label_ants, transform)
    image_bees = "bees_image"
    label_bees = "bees_label"
    bees_dataset = MyData(root_dir, image_bees, label_bees, transform)
    train_dataset = ants_dataset + bees_dataset

    # transforms = transforms.Compose([transforms.Resize(256, 256)])
    dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2)

    # writer.add_image('error', train_dataset[119]['img'])
    # writer.close()
    # for i, j in enumerate(dataloader):
    #     # imgs, labels = j
    #     print(type(j))
    #     print(i, j['img'].shape)
    #     # writer.add_image("train_data_b2", make_grid(j['img']), i)
    
    # writer.close()

#  jupyter notebook 等方法,可以查看 help
'''

'''
以下内容是视频中完全一样的代码,截图,在 20210713 的笔记中,包括 python console 的代码也有保存
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98

运行结果

在这里插入图片描述

完整目录

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/351052
推荐阅读
相关标签
  

闽ICP备14008679号