赞
踩
import os from PIL import Image from torch.utils.data import Dataset # dataset有两个作用:1、加载每一个数据,并获取其label;2、用len()查看数据集的长度 class MyData(Dataset): def __init__(self, root_dir, label_dir): # 初始化,为这个函数用来设置在类中的全局变量 self.root_dir = root_dir self.label_dir = label_dir self.path = os.path.join(self.root_dir,self.label_dir) # 单纯的连接起来而已,背下来怎么用就好了,因为在win下和linux下的斜线方向不一样,所以用这个函数来连接路径 self.img_path = os.listdir(self.path) # img_path 的返回值,就已经是一个列表了 def __getitem__(self, idx): # 获取数据对应的 label img_name = self.img_path[idx] # img_name 在上一个函数的最后,返回就是一个列表了 img_item_path = os.path.join(self.root_dir, self.label_dir, img_name) # 这行的返回,是一个图片的路径,加上图片的名称了,能够直接定位到某一张图片了 img = Image.open(img_item_path) # 这个步骤看来是不可缺少的,要想 show 或者 操作图片之前,必须要把图片打开(读取),也就是 Image.open()一下,这可能是 PIL 这个类型图片的特有操作 label = self.label_dir # 这个例子中,比较特殊,因为图片的 label 值,就是图片所在上一级的目录 return img, label # img 是每一张图片的名称,根据这个名称,就可以使用查看(直接img)、print、size等功能 # label 是这个图片的标签,在当前这个类中,标签,就是只文件夹名称,因为我们就是这样定义的 def __len__(self): return len(self.img_path) # img_path,已经是一个列表了,len()就是在对这个列表进行一些操作 if __name__ == '__main__': root_dir = "F:\\PhD\\01-Python_In_One\\Project\\【B_up】XiaoTuDui\\data\\train" # root_dir = "data/train" ants_label_dir = "ants_image" bees_label_dir = "bees_image" ants_dataset = MyData(root_dir, ants_label_dir) bees_dataset = MyData(root_dir, bees_label_dir) train_dataset = ants_dataset + bees_dataset
# from torch.utils.tensorboard import SummaryWriter # !usr/bin/env python3 # -*- coding:utf-8 -*- """ author :24nemo date :2021年07月12日 """ ''' Dataset: 能把数据进行编号 提供一种方式,获取数据,及其label,实现两个功能: 1、如何获取每一个数据,及其label 2、告诉我们总共有多少个数据 数据集的组织形式,有两种方式: 1、文件夹的名字,就是数据的label 2、文件名和label,分别处在两个文件夹中,label可以用txt的格式进行存储 在jupyter中,可以查看,help,两个方式: 1、help(Dataset) 2、Dataset?? Dataloader: 为网络提供不同的数据形式,比如将0、1、2、3进行打包 这一节内容很重要 ''' ''' # writer = SummaryWriter("logs") class MyData(Dataset): def __init__(self, root_dir, image_dir, label_dir, transform): # 初始化,为这个函数用来设置在类中的全局变量 self.root_dir = root_dir self.image_dir = image_dir self.label_dir = label_dir self.label_path = os.path.join(self.root_dir, self.label_dir) self.image_path = os.path.join(self.root_dir, self.image_dir) self.image_list = os.listdir(self.image_path) self.label_list = os.listdir(self.label_path) self.transform = transform # 因为 label 和 Image文件名相同,进行一样的排序,可以保证取出的数据和label是一一对应的 self.image_list.sort() self.label_list.sort() def __getitem__(self, idx): img_name = self.image_list[idx] label_name = self.label_list[idx] img_item_path = os.path.join(self.root_dir, self.image_dir, img_name) label_item_path = os.path.join(self.root_dir, self.label_dir, label_name) img = Image.open(img_item_path) with open(label_item_path, 'r') as f: label = f.readline() # img = np.array(img) img = self.transform(img) sample = {'img': img, 'label': label} return sample def __len__(self): # assert len(self.image_list) == len(self.label_list) return len(self.image_list) if __name__ == '__main__': transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()]) root_dir = "dataset/train" image_ants = "ants_image" label_ants = "ants_label" ants_dataset = MyData(root_dir, image_ants, label_ants, transform) image_bees = "bees_image" label_bees = "bees_label" bees_dataset = MyData(root_dir, image_bees, label_bees, transform) train_dataset = ants_dataset + bees_dataset # transforms = transforms.Compose([transforms.Resize(256, 256)]) dataloader = DataLoader(train_dataset, batch_size=1, num_workers=2) # writer.add_image('error', train_dataset[119]['img']) # writer.close() # for i, j in enumerate(dataloader): # # imgs, labels = j # print(type(j)) # print(i, j['img'].shape) # # writer.add_image("train_data_b2", make_grid(j['img']), i) # writer.close() # jupyter notebook 等方法,可以查看 help ''' ''' 以下内容是视频中完全一样的代码,截图,在 20210713 的笔记中,包括 python console 的代码也有保存 '''
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。