当前位置:   article > 正文

Pytorch学习笔记(3):图像的预处理(transforms)_pytorch图像预处理

pytorch图像预处理

 

目录

 一、torchvision:计算机视觉工具包

 二、transforms的运行机制

(1)torchvision.transforms:常用的图像预处理方法

(2)transforms运行原理 

 三、数据标准化

transforms.Normalize()

四、数据增强

 4.1 transforms—数据裁剪

(1)transforms.CentorCrop

(2)transforms.RandomCrop

(3)RandomResizedCrop

(4)FiveCrop &(5)TenCrop

4.2 transforms——翻转和旋转

(1)RandomHorizontalFlip & (2)RandomVerticalFlip

(3)RandomRotation()

 4.3 transforms—图像变换

(1)pad

(2)ColorJitter

 (3)Greyscale 

(4)RandomGreyscale

(5)RandomAffine

(6)RandomErasing

 (7)transforms.lambda

4.4 transforms——transforms方法选择操作

(1)transforms.RandomChoice

 (2)transforms.RandomApply

(3)transforms.RandomOrder

4.5 自定义transforms方法 

 五、总结:二十二种transforms操作

 一、裁剪

 二、翻转和旋转

 三、图像变换

 四、transforms的操作


前情回顾:

Pytorch学习笔记(1):基本概念、安装、张量操作、逻辑回归

Pytorch学习笔记(2):数据读取机制(DataLoader与Dataset)


 一、torchvision:计算机视觉工具包

• torchvision.transforms : 常用的图像预处理方法
• torchvision.datasets : 常用数据集的dataset实现,MNIST,CIFAR-10,ImageNet等
• torchvision.model : 常用的模型预训练,AlexNet,VGG, ResNet,GoogLeNet等


 二、transforms的运行机制

(1)torchvision.transforms:常用的图像预处理方法

数据预处理方法:数据中心化;数据标准化;缩放;裁剪;旋转;填充;噪声添加;灰度变换;线性变换;仿射变换;亮度、饱和度及对比度变换等

compose将一系列transforms方法进行有序组合包装,依次按顺序的对图像进行操作

具体代码段如下:

导入:import torchvision.transforms as transforms

  1. #训练集数据预处理
  2. train_transform = transforms.Compose([
  3. transforms.Resize((32, 32)), #缩放
  4. transforms.RandomCrop(32, padding=4), #随机裁剪
  5. transforms.ToTensor(), #转为tensor,同时进行归一化操作,将像素值的区间从0-255变为0-1
  6. transforms.Normalize(norm_mean, norm_std), #数据标准化,均值变为0,标准差变为1
  7. ])
  8. #验证集数据预处理
  9. valid_transform = transforms.Compose([ #测试时不需要数据增强
  10. transforms.Resize((32, 32)),
  11. transforms.ToTensor(),
  12. transforms.Normalize(norm_mean, norm_std),

• transforms.Compose: 将一系列的transforms方法进行有序的组合包装,依次按顺序的对图像进行操作
• transforms.Resize: 改变图像大小
• transforms.RandomCrop: 对图像进行裁剪(这个在训练集里面用,验证集就用不到了)
• transforms.ToTensor: 将图像转换成张量,同时会进行归一化的一个操作,将张量的值从0-255转到0-1
• transforms.Normalize: 将数据进行标准化


(2)transforms运行原理 

把这两个transforms操作作为参数传给Dataset,在Dataset__getitem__()方法中做图像增强。

具体代码段如下:

  1. def __getitem__(self, index):
  2. path_img, label = self.data_info[index]
  3. img = Image.open(path_img).convert('RGB') # 0~255
  4. if self.transform is not None:
  5. img = self.transform(img) # 在这里做transform,转为tensor等等
  6. return img, label

进入transforms,跳转到transforms的call函数

依次有序的从compose中调用数据处理方法

  1. def __call__(self, img):
  2. for t in self.transforms:
  3. img = t(img)
  4. return img

逻辑关系可以用下图表示:


 三、数据标准化

transforms.Normalize()

功能:逐channel的对图像进行标准化。output = (input - mean)/ std

• mean:各通道的均值

• std:各通道的标准差

• inplace:是否原地操作

具体代码段如下:

此处直接调用的torch中的normalize函数

  1. class Normalize(torch.nn.Module):
  2. def __init__(self, mean, std, inplace=False):
  3. super().__init__()
  4. self.mean = mean
  5. self.std = std
  6. self.inplace = inplace
  7. def forward(self, tensor: Tensor) -> Tensor:
  8. """
  9. Args:
  10. tensor (Tensor): Tensor image to be normalized.
  11. Returns:
  12. Tensor: Normalized Tensor image.
  13. """
  14. return F.normalize(tensor, self.mean, self.std, self.inplace)
  15. def __repr__(self):
  16. return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

进入torch的normalize函数

  1. def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor:
  2. #判断是否是tensor
  3. if not isinstance(tensor, torch.Tensor):
  4. raise TypeError('Input tensor should be a torch tensor. Got {}.'.format(type(tensor)))
  5. if tensor.ndim < 3:
  6. raise ValueError('Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = '
  7. '{}.'.format(tensor.size()))
  8. #是否进行原位操作,False则对tensor进行clone
  9. if not inplace:
  10. tensor = tensor.clone()
  11. dtype = tensor.dtype
  12. #将均值和标准差由列表格式转换为tensor格式
  13. mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
  14. std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
  15. if (std == 0).any():
  16. raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
  17. if mean.ndim == 1:
  18. mean = mean.view(-1, 1, 1)
  19. if std.ndim == 1:
  20. std = std.view(-1, 1, 1)
  21. tensor.sub_(mean).div_(std)
  22. return tensor

四、数据增强

数据增强又称为数据增广, 数据扩增,是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力, 下面是一个数据增强的小例子。

 4.1 transforms—数据裁剪

(1)transforms.CentorCrop

功能:从图像中心裁剪图片

torchvision.transforms.CenterCrop(size)

• size:所需裁剪图片尺寸


(2)transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片


• size:所需裁剪图片尺寸
• padding:设置填充大小

当为a时,上下左右均填充a个像素
当为(a, b)时,上下填充b个像素,左右填充a个像素
当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式

  1. constant:像素值由fill设定
  2. edge:像素值由图像边缘像素决定
  3. reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4. symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]

• fill:constant时,设置填充的像素值

具体代码段如下:

  1. # 测试RandomCrop随机裁剪
  2. trans_random = transforms.RandomCrop(300)
  3. trans_compose_2 = transforms.Compose([trans_random, trans_totensor])
  4. for i in range(10): # 0裁剪10次
  5. img_crop = trans_compose_2(img)
  6. writer.add_image("RandomCrop", img_crop, i)

(3)RandomResizedCrop

功能:随机大小、长宽比裁剪图片

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例,默认(0.08,1)    (在0.08-1之间选择一个比例进行裁剪)
• ratio:随机长宽比,默认(3/4,4/3)
• interpolation:插值方法        (由于裁剪之后的图片可能会小于size,故进行插值操作)

  1. PIL.Image.NEAREST        
  2. PIL.Image.BILINEAR
  3. PIL.Image.BICUBIC 

(4)FiveCrop &(5)TenCrop

功能:在图像的上下左右及中心裁剪出尺寸为size的5张图片,TenCrop还在这5张图片的基础上再水平或者垂直镜像得到10张图片

• size:所需裁剪图片尺寸

• vertical_flip:是否垂直翻转


4.2 transforms——翻转和旋转

(1)RandomHorizontalFlip & (2)RandomVerticalFlip

功能:依概率水平(左右)或垂直(上下)翻转图片

  • p:翻转概率

(3)RandomRotation()

功能:随机旋转图片

• degrees:旋转角度
当为a时,在(-a,a)之间选择旋转角度
当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图信息
• center:旋转点设置,默认中心旋转

完整代码:

  1. # -*- coding: utf-8 -*-
  2. import os
  3. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  4. import numpy as np
  5. import torch
  6. import random
  7. from torch.utils.data import DataLoader
  8. import torchvision.transforms as transforms
  9. from PIL import Image
  10. from matplotlib import pyplot as plt
  11. path_lenet = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "model", "lenet.py"))
  12. path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
  13. assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
  14. assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))
  15. import sys
  16. hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
  17. sys.path.append(hello_pytorch_DIR)
  18. from tools.my_dataset import RMBDataset
  19. from tools.common_tools import set_seed, transform_invert
  20. set_seed(1) # 设置随机种子
  21. # 参数设置
  22. MAX_EPOCH = 10
  23. BATCH_SIZE = 1
  24. LR = 0.01
  25. log_interval = 10
  26. val_interval = 1
  27. rmb_label = {"1": 0, "100": 1}
  28. # ============================ step 1/5 数据 ============================
  29. split_dir = os.path.abspath(os.path.join("..", "..", "data", "RMB_data", "rmb_split"))
  30. if not os.path.exists(split_dir):
  31. raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))
  32. train_dir = os.path.join(split_dir, "train")
  33. valid_dir = os.path.join(split_dir, "valid")
  34. norm_mean = [0.485, 0.456, 0.406]
  35. norm_std = [0.229, 0.224, 0.225]
  36. train_transform = transforms.Compose([
  37. transforms.Resize((224, 224)), #统一图片尺寸
  38. # 1 CenterCrop
  39. # transforms.CenterCrop(196), # 512
  40. # 2 RandomCrop
  41. # transforms.RandomCrop(224, padding=16),
  42. # transforms.RandomCrop(224, padding=(16, 64)),
  43. # transforms.RandomCrop(224, padding=16, fill=(255, 0, 0)),
  44. # transforms.RandomCrop(512, pad_if_needed=True), # pad_if_needed=True
  45. # transforms.RandomCrop(224, padding=64, padding_mode='edge'),
  46. # transforms.RandomCrop(224, padding=64, padding_mode='reflect'),
  47. # transforms.RandomCrop(1024, padding=1024, padding_mode='symmetric'),
  48. # 3 RandomResizedCrop
  49. # transforms.RandomResizedCrop(size=224, scale=(0.5, 0.5)),
  50. # 4 FiveCrop
  51. # transforms.FiveCrop(112),
  52. # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
  53. # 5 TenCrop
  54. # transforms.TenCrop(112, vertical_flip=False),
  55. # transforms.Lambda(lambda crops: torch.stack([(transforms.ToTensor()(crop)) for crop in crops])),
  56. # 1 Horizontal Flip
  57. # transforms.RandomHorizontalFlip(p=1),
  58. # 2 Vertical Flip
  59. # transforms.RandomVerticalFlip(p=0.5),
  60. # 3 RandomRotation
  61. # transforms.RandomRotation(90),
  62. # transforms.RandomRotation((90), expand=True),
  63. # transforms.RandomRotation(30, center=(0, 0)),
  64. # transforms.RandomRotation(30, center=(0, 0), expand=True), # expand only for center rotation
  65. transforms.ToTensor(),
  66. transforms.Normalize(norm_mean, norm_std),
  67. ])
  68. valid_transform = transforms.Compose([
  69. transforms.Resize((224, 224)),
  70. transforms.ToTensor(),
  71. transforms.Normalize(norm_mean, norm_std)
  72. ])
  73. # 构建MyDataset实例
  74. train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
  75. valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)
  76. # 构建DataLoder
  77. train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
  78. valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
  79. # ============================ step 5/5 训练 ============================
  80. for epoch in range(MAX_EPOCH):
  81. for i, data in enumerate(train_loader):
  82. inputs, labels = data # B C H W
  83. img_tensor = inputs[0, ...] # C H W
  84. #invert函数对transforms进行逆操作,可以将浮点数据转为img,便于观察
  85. img = transform_invert(img_tensor, train_transform)
  86. plt.imshow(img)
  87. plt.show()
  88. plt.pause(0.5)
  89. plt.close()
  90. # FiveCrop 和 TenCrop的可视化操作,因为输出为5维
  91. # bs, ncrops, c, h, w = inputs.shape
  92. # for n in range(ncrops):
  93. # img_tensor = inputs[0, n, ...] # C H W
  94. # img = transform_invert(img_tensor, train_transform)
  95. # plt.imshow(img)
  96. # plt.show()
  97. # plt.pause(1)

 4.3 transforms—图像变换

(1)pad

功能:对图片边缘进行填充

• padding:设置填充大小
 当为a时,上下左右均填充a个像素
 当为(a,b)时,上下填充b个像素,左右填充a个像素
 当为(a,b,c,d)时,左,上,右,下分别填充a,b,c,d
• padding_mode:填充模式,有四种模式,constant、edge、reflect和symmetric(具体请见三.2.(2)节)
• fill:constant时, 设置填充的像素值,(R,G,B)or(Gray)
padding_mode优先级高于fill


(2)ColorJitter

功能:调整亮度、对比度、饱和度和色相, 这个是比较实用的方法。

• brightness:亮度调整因子

  • 当为a时,从[max(0,1-a),1+a]中随机选择
  • 当为(a,b)时,从[a,b]中选择

• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数

  • 当为a时,从[-a,a]中选择参数,注:0<=a<=0.5
  • 当为(a,b)时,从[a,b]中选择参数,注:-0.5<=a<=b<=0.5

 (3)Greyscale 

 功能:将图片转换为灰度图

• num_output_channels: 输出的通道数。只能设置为 1 或者 3 (如果在后面使用了transforms.Normalize,则要设置为 3,因为transforms.Normalize只能接收 3 通道的输入)


(4)RandomGreyscale

功能:依概率将图片转换为灰度图

  • num_output_channels:输出通道数,只能设1或3
  • p:概率值,图像被转换为灰度图的概率,当p=1,则等价于Greyscale

(5)RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转,平移,缩放,错切和翻转

• degrees:旋转角度设置
• translate:平移区间设置,如(a,b),a设置宽(width),b设置高(height),图像在宽维度平移的区间为  -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置
• shear:错切角度设置,有水平错切和垂直错切

  • 若为a,则仅在x轴错切,错切角度在(-a,a)之间
  • 若为(a,b),则设置x轴角度,b设置y的角度
  • 若为(a,b,c,d),则a,b设置x轴角度,c,d设置y轴角度

• resample:重采样方式,有NEAREST、BILINEAR、BICUBIC


(6)RandomErasing

功能:对图像进行随机遮挡

  • p:概率值,执行该操作的概率
  • scale:遮挡区域的面积
  • ratio:遮挡区域长宽比
  • value:设置遮挡区域的像素值,(R,G,B)or(Grey)

注意事项:执行Erasing是对tensor进行操作的,故需要把输入转为张量的类型 ,transforms.ToTensor() 

遮挡效果如下:


 (7)transforms.lambda

功能:用户自定义lambda方法

• lambd:lambda匿名函数
  • lambda [arg1[,arg2,...,argn]] : expression 
TenCrop输出的结果是tuple类型,故需要对输出结果转换为tensor,就可以用到lambda函数

stack将返回的张量进行拼接,输出为4D的张量,stack会创建一个维度将张量进行拼接

举个栗子:

  1. transforms.TenCrop(200, vertical_flip=True),
  2. transforms.Lambda(lambda crops: torch.stack([transforms.Totensor()(crop) for crop in crops])),

4.4 transforms——transforms方法选择操作

(1)transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个


 (2)transforms.RandomApply

功能:依据概率执行一组transforms操作 


(3)transforms.RandomOrder

功能:对一组transforms操作打乱顺序


4.5 自定义transforms方法 

自定义transforms要素:

  • 仅接收一个参数,返回一个参数
  • 注意上下游的输入与输出

我们对Compose里面的这些transforms方法执行一个for循环,每次挑取一个方法进行执行。 也就是transforms方法仅接收一个参数,返回一个参数,然后就是for循环中,上一个transforms的输出正好是下一个transforms的输入,所以数据类型要注意匹配。 这就是自定义transforms的两个要素。
下面给出一个自定义transforms的结构:

数据增强策略原则: 让训练集与测试集更接近

  • 空间位置上: 可以选择平移
  • 色彩上: 灰度图,色彩抖动
  • 形状: 仿射变换
  • 上下文场景: 遮挡,填充

 五、总结:二十二种transforms操作

 一、裁剪

• transforms.CenterCrop
• transforms.RandomCrop
• transforms.RandomResizedCrop
• transforms.FiveCrop
• transforms.TenCrop

二、翻转和旋转

• transforms.RandomHorizontalFlip
• transforms.RandomVerticalFlip
• transforms.RandomRotation

三、图像变换

• transforms.Pad
• transforms.ColorJitter
• transforms.Grayscale
• transforms.RandomGrayscale
• transforms.RandomAffine
• transforms.LinearTransformation
• transforms.RandomErasing
• transforms.Lambda
• transforms.Resize
• transforms.Totensor
• transforms.Normalize

四、transforms的操作

• transforms.RandomChoice
• transforms.RandomApply
• transforms.RandomOrd


本文参考:

[PyTorch 学习笔记] 2.3 二十二种 transforms 图片数据预处理方法 - 知乎 (zhihu.com)

Pytorch基础学习(第二章-Pytorch数据处理)

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号