赞
踩
深度学习初入门小白,技艺不精,写下笔记记录自己的学习过程。欢迎评论区交流提问,力所能及之问题,定当毫无保留之相授。
Dataset:是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
Dataloader:通过DataLoader这个函数,我们在加载数据集时候,批次读取数据及多线程并行处理,这样可以加快我们读取数据集的速度。
Dataset类是Pytorch中数据集加载类中应该继承的父类。通常包括这三部分:
1.*def __init__(self)*
2.*def __getitem__(self, index):*
3.*def __len__(self):*
其中父类中的两个私有成员函数,__len__和__getitem__必须被重载!
#root1和root2分别为训练集,验证集存放图片路径及标签的txt路径 root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt" root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt" # 1、构建数据集类 class Mydata(Dataset): # __init__ # 该函数可以包含多个参数,如数据的读取路径和对数据的处理设置等一系列设定 # txt:存放着图片数据的路径和标签信息,words[0]为图片的路径,words[1]为图片的标签,如下图所示。(txt需要事先生成,如何生成先挖个坑) # imgs:按行读取txt,并依次存放到列表中 # transform为:图片数据增强,下文中会讲 def __init__(self, txt, transform=None, target_transform=None): super(Mydata, self).__init__() imgs = [] fh = open(txt, 'r') for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0], int(words[1]))) # imgs中包含有图像路径和标签 self.txt = txt self.imgs = imgs self.transform = transform self.target_transform = target_transform # __getitem__ # 接收一个index,然后返回图片路径和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息。 # 在本代码中,这个list为imgs[] # 图片打开方式为Image.open,三通道RGB格式。若数据集图片为单通道,可在transform中添加transforms.Grayscale(1)函数。 def __getitem__(self, index): fn, label = self.imgs[index] img = Image.open(os.path.join(self.txt[:-4], fn))#self.txt[:-4],下文加载txt时,路径中不需要有后缀,所以去掉.txt四个字符 if self.transform is not None: img = self.transform(img) return img, label #__len__ #返回样本的总数量, 该方法提供了dataset的大小 def __len__(self): return len(self.imgs) train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(), transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) test_transform = transforms.Compose([transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) train_data = Mydata(txt=root1, transform=train_transform) test_data = Mydata(txt=root2, transform=test_transform)
txt中存放着图片的路径及标签
该函数的作用是将数据整理成一个batch,即根据batch_size的大小一次性在数据集中取出batch_size个数据。例如数据集中有100条数据,batch_size的值为20,则每次在100条数据中取出20条数据。
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
# dataset: 加载torch.utils.data.Dataset对象数据,即为上文中的train_data和test_data
# batch_size: 每个batch的大小
# shuffle:是否对数据进行打乱
# drop_last:是否对无法整除的最后一个datasize进行丢弃
# um_workers:表示加载的时候子进程数,一般GPU使用
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
examples = enumerate(train_loader)
batch_idx, (examples_data, examples_targets) = next(examples)
fig = plt.figure()
for i in range(6):
plt.subplot(2, 3, i+1)
plt.tight_layout()#自动调整子图参数,使之填充满整个图像区域
plt.imshow(examples_data[i][0], interpolation='none')
plt.title("Category:{}".format(examples_targets[i]))
plt.xticks([])
plt.yticks([])
plt.show()
注意:
1.数据集的路径需要改成自己的
2.前提需要生成相应的txt文件
import torchvision.transforms as transforms from PIL import Image from torch.utils.data import Dataset, DataLoader import matplotlib.pyplot as plt import os root1 = r"C:\Users\asus\Desktop\mstar_classification\mstar\train.txt" root2 = r"C:\Users\asus\Desktop\mstar_classification\mstar\val.txt" # 1、构建数据集 class Mydata(Dataset): def __init__(self, txt, transform=None, target_transform=None): super(Mydata, self).__init__() self.txt = txt fh = open(txt, 'r') imgs = [] for line in fh: line = line.strip('\n') line = line.rstrip() words = line.split() imgs.append((words[0], int(words[1]))) # imgs中包含有图像路径和标签 self.imgs = imgs self.transform = transform self.target_transform = target_transform def __getitem__(self, index): fn, label = self.imgs[index] img = Image.open(os.path.join(self.txt[:-4], fn)) if self.transform is not None: img = self.transform(img) return img, label def __len__(self): return len(self.imgs) # 2.数据增强、加载数据 train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ColorJitter(), transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) test_transform = transforms.Compose( [transforms.Grayscale(1), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) # 是被封装进DataLoader里,实现该方法封装自己的数据和标签 train_data = Mydata(txt=root1, transform=train_transform) test_data = Mydata(txt=root2, transform=test_transform) # DataLoader被封装入DataLoader里,实现该方法达到数据的划分 # train_data 和test_data包含多有的训练与测试数据,调用DataLoader批量加载 train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True) test_loader = DataLoader(dataset=test_data, batch_size=64) # 3.可视化源数据 examples = enumerate(train_loader) batch_idx, (examples_data, examples_targets) = next(examples) fig = plt.figure() for i in range(6): plt.subplot(2, 3, i + 1) plt.tight_layout() # 自动调整子图参数,使之填充满整个图像区域 plt.imshow(examples_data[i][0], interpolation='none') plt.title("Category:{}".format(examples_targets[i])) plt.xticks([]) plt.yticks([]) plt.show()
https://blog.csdn.net/sinat_42239797/article/details/90641659
https://blog.csdn.net/ChaoFeiLi/article/details/109764566
https://blog.csdn.net/l8947943/article/details/103733473
https://blog.csdn.net/kahuifu/article/details/108654421
https://blog.csdn.net/wangkaidehao/article/details/104209685
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。