当前位置:   article > 正文

利用transforms Dataset DataLoader对图像数据进行处理并构建自己的数据集_制作用于transform模型的图片数据集

制作用于transform模型的图片数据集

1. torchvision.transforms

在CV任务中,可以用此对图像进行预处理,数据增强等操作

1.1 Transforms on Image

import torchvision.transforms as transforms
from PIL import Image

img = Image.open('lena.png')
img = img.convert("RGB")
img
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在这里插入图片描述

width, height = img.size
print(width, height)
  • 1
  • 2
132 193
  • 1
1.1.1 transforms.Resize

把给定的图片resize到给定的size

size = (100, 100)
transform = transforms.Resize(size=size)
resize_img = transform(img)
resize_img
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.1.2 transforms.CenterCrop

在图片的中心区域进行裁剪

size = (100, 100)
transform = transforms.CenterCrop(size=size)
centercrop_img = transform(img)
centercrop_img
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.1.3 transforms.RandomCrop

在图片上随机一个位置进行裁剪

size = (100, 100)
transform = transforms.RandomCrop(size=size)
randomcrop_img = transform(img)
randomcrop_img
  • 1
  • 2
  • 3
  • 4

在这里插入图片描述

1.1.4 transforms.RandomHorizontalFlip§

以概率为p水平翻转给定的图像

transform = transforms.RandomHorizontalFlip(p=0.5)
rpf_img = transform(img)
rpf_img
  • 1
  • 2
  • 3

在这里插入图片描述

1.1.5 transforms.RandomVerticalFlip§

以概率为p垂直翻转给定的图像

transform = transforms.RandomVerticalFlip(p=0.5)
rvf_img = transform(img)
rvf_img
  • 1
  • 2
  • 3

在这里插入图片描述

1.1.6 transforms.ColorJitter

随机修改图片的亮度、对比度和饱和度,常用来进行数据增强

brightness = (1, 10)
contrast = (1, 10)
saturation = (1, 10)
hue = (0.2, 0.4)
transform = transforms.ColorJitter(brightness, contrast, saturation, hue)
colorjitter_img = transform(img)
colorjitter_img
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

在这里插入图片描述

1.1.7 transforms.Grayscale

将图像转换为灰度图像

transform = transforms.Grayscale()
gary_img = transform(img)
gary_img
  • 1
  • 2
  • 3

在这里插入图片描述

1.1.8 transforms.RandomGrayscale

以概率p将图像转换为灰度图像

transform = transforms.RandomGrayscale(p=0.5)
rg_img = transform(img)
rg_img
  • 1
  • 2
  • 3

在这里插入图片描述

1.2 transforms on Tensor

1.2.1 transforms.ToTensor()

将Image转换为Tensor

transform = transforms.ToTensor()
tensor_img = transform(img)
tensor_img
  • 1
  • 2
  • 3
tensor([[[0.7176, 0.7294, 0.7255,  ..., 0.6627, 0.6549, 0.6627],
         [0.7137, 0.7176, 0.7176,  ..., 0.6510, 0.6510, 0.6549],
         [0.7137, 0.7176, 0.7137,  ..., 0.6392, 0.6431, 0.6353],
         ...,
         [0.9922, 1.0000, 0.9725,  ..., 0.6863, 0.6902, 0.7059],
         [1.0000, 1.0000, 0.9961,  ..., 0.6745, 0.6824, 0.6902],
         [1.0000, 0.9961, 0.9882,  ..., 0.6745, 0.6745, 0.6863]],

        [[0.3843, 0.3922, 0.3922,  ..., 0.3529, 0.3451, 0.3529],
         [0.3765, 0.3804, 0.3804,  ..., 0.3412, 0.3412, 0.3412],
         [0.3765, 0.3804, 0.3804,  ..., 0.3294, 0.3412, 0.3333],
         ...,
         [0.8745, 0.8941, 0.8863,  ..., 0.3294, 0.3490, 0.3647],
         [0.9098, 0.9176, 0.9176,  ..., 0.3216, 0.3373, 0.3490],
         [0.9294, 0.9255, 0.9255,  ..., 0.3216, 0.3294, 0.3412]],

        [[0.2745, 0.2863, 0.2784,  ..., 0.2353, 0.2235, 0.2353],
         [0.2784, 0.2745, 0.2745,  ..., 0.2353, 0.2353, 0.2314],
         [0.2784, 0.2745, 0.2706,  ..., 0.2275, 0.2392, 0.2353],
         ...,
         [0.8706, 0.8824, 0.8627,  ..., 0.2510, 0.2706, 0.2863],
         [0.9216, 0.9176, 0.9059,  ..., 0.2392, 0.2588, 0.2706],
         [0.9451, 0.9333, 0.9255,  ..., 0.2392, 0.2510, 0.2588]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
1.2.2 transforms.Normalize

input[channel] = (input[channel] - mean[channel]) / std[channel]

transform = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
img_normal = transform(tensor_img)
img_normal
  • 1
  • 2
  • 3
tensor([[[ 0.4353,  0.4588,  0.4510,  ...,  0.3255,  0.3098,  0.3255],
         [ 0.4275,  0.4353,  0.4353,  ...,  0.3020,  0.3020,  0.3098],
         [ 0.4275,  0.4353,  0.4275,  ...,  0.2784,  0.2863,  0.2706],
         ...,
         [ 0.9843,  1.0000,  0.9451,  ...,  0.3725,  0.3804,  0.4118],
         [ 1.0000,  1.0000,  0.9922,  ...,  0.3490,  0.3647,  0.3804],
         [ 1.0000,  0.9922,  0.9765,  ...,  0.3490,  0.3490,  0.3725]],

        [[-0.2314, -0.2157, -0.2157,  ..., -0.2941, -0.3098, -0.2941],
         [-0.2471, -0.2392, -0.2392,  ..., -0.3176, -0.3176, -0.3176],
         [-0.2471, -0.2392, -0.2392,  ..., -0.3412, -0.3176, -0.3333],
         ...,
         [ 0.7490,  0.7882,  0.7725,  ..., -0.3412, -0.3020, -0.2706],
         [ 0.8196,  0.8353,  0.8353,  ..., -0.3569, -0.3255, -0.3020],
         [ 0.8588,  0.8510,  0.8510,  ..., -0.3569, -0.3412, -0.3176]],

        [[-0.4510, -0.4275, -0.4431,  ..., -0.5294, -0.5529, -0.5294],
         [-0.4431, -0.4510, -0.4510,  ..., -0.5294, -0.5294, -0.5373],
         [-0.4431, -0.4510, -0.4588,  ..., -0.5451, -0.5216, -0.5294],
         ...,
         [ 0.7412,  0.7647,  0.7255,  ..., -0.4980, -0.4588, -0.4275],
         [ 0.8431,  0.8353,  0.8118,  ..., -0.5216, -0.4824, -0.4588],
         [ 0.8902,  0.8667,  0.8510,  ..., -0.5216, -0.4980, -0.4824]]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
1.2.3 transforms.Compose

将多个变换组合在一起

img = Image.open('lena.png')
img = img.convert('RGB')

transform = transforms.Compose([
    transforms.Resize(100),
    transforms.RandomHorizontalFlip(),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

img_compose = transform(img)
img_compose.size()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
torch.Size([3, 64, 64])
  • 1

2. torchvision.datasets

用来进行数据加载的,下面以CIFAR-10数据集为例,其中transform表示对数据进行预处理,对应着上面所讲

import torchvision

trainset = torchvision.datasets.CIFAR10(
    root='./dataset',  # 数据集下载的地方
    train=True,   # True表示创建训练集;False表示创建测试集
    download=True, # 如果为true,则从Internet下载数据集。如果已下载数据集,则不会再次下载
    transform=None  # 表示是否对数据进行预处理,None表示不做任何处理
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

3. torch.utils.data.DataLoader

import torch
from torch.utils.data.sampler import SubsetRandomSampler

trainloader = torch.utils.data.DataLoader(
    dataset=trainset,  # 加载torch.utils.data.Dataset对象数据或者是torchvision.datasets中的数据
    batch_size=1, # 每个batch所含样本的大小
    shuffle=False, # 是否对数据进行打乱
    sampler=SubsetRandomSampler(indices=), # 按指定下标进行取样,如果此参数被指定,shuffle参数必须为False
    drop_last=False, # 当整个数据集不能整除batch_size,False表示最后一个batch的大小会变小,True表示直接丢弃最后一个batch
    num_workers=0 # 表示加载的时候子进程数
)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

4. torch.utils.data.Dataset

from torch.utils.data.dataset import Dataset


# 基本框架
class CustomDataset(Dataset):
    def __init__(self):
    	"""
    	一些初始化过程写在这里
    	"""
        # TODO
        # 1. Initialize file paths or a list of file names. 
        pass
    def __getitem__(self, index):
    	"""
    	返回数据和标签,可以这样显示调用:
    	img, label = MyCustomDataset.__getitem__(index)
    	"""
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        pass
    def __len__(self):
    	"""
    	返回所有数据的数量
    	"""
        # You should change 9 to the total size of your dataset.
        return 9 # e.g. 9 is size of dataset

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

目前我们有一个关于图像分类的问题,数据结构如下:

在这里插入图片描述

其中一个是训练文件夹,一个测试文件夹,分类的类别数为6个,其中每个文件夹包含很多图片

如何构建Custom Dataset

  1. 分别为训练集和测试集建立两个DataFrame文件,其中DataFrame文件有两列,一列是图片的名字,令一列为标签
ImagesLabels
0.jpg0
99.jpg5
  1. 构建Custom Dataset
class INTELDataset(Dataset):
    def __init__(self, img_data,img_path,transform=None):
        self.img_path = img_path    # 数据路径
        self.transform = transform
        self.img_data = img_data  # DaraFrame
        
   
    
    def __getitem__(self, index):
        img_name = os.path.join(self.img_path,self.img_data.loc[index, 'labels'],
                                self.img_data.loc[index, 'Images'])  # 图片路径
        image = Image.open(img_name)  # 获得图片
        image = image.convert('RGB')
        label = torch.tensor(self.img_data.loc[index, 'labels'])  # 获得标签
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    
    
    def __len__(self):
        return len(self.img_data)  # 数据大小
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/364843
推荐阅读
相关标签
  

闽ICP备14008679号