当前位置:   article > 正文

Pytorch for training1——read data/image

Pytorch for training1——read data/image

blog

torch.utils.data.Dataset

  1. create dataset with class torch.utils.data.Dataset automaticly
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        # 根据索引获取样本
        return self.data[index]

    def __len__(self):
        # 返回数据集大小
        return len(self.data)

# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)

# 根据索引获取样本
sample = dataset[2]
print(sample)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

torchvision.datasets

  1. load data from classic dataset
import torch
from torchvision import datasets, transforms

# 定义数据转换
transform = transforms.Compose([
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5,), (0.5,))  # 标准化图像
])

# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

2. load data from Imagefolder with transform

from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# transform.Compose是PyTorch中的一个类,用于将多个图像变换操作组合在一起。它的作用是将这些操作按照顺序依次应用于输入的图像数据。
trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])

dataset = datasets.ImageFolder(data_dir, transform=trans)
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    collate_fn=training.collate_pil
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

在这里插入图片描述

3. Introduction of Imagefolder

# 定义输入图像的数据加载器
mytransform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
dataset = datasets.ImageFolder(data_dir, transform=mytransform)	#对于图像,必须transform转成Tensor,才能for input,label in train_loader读取
print(dataset)
print(len(dataset))
print(len(dataset.imgs))
print(len(dataset.classes))
print(dataset.classes[-1])
print(dataset.classes)
print(dataset.imgs)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

在这里插入图片描述

\root
	\cls1
		\img1.png
		\img2.png
	\cls2
		\img1.png
		\img2.png
	\cls3
		\img1.png
		\img2.png
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

(img,cls) in dataset.imgs

# img_list_1=[img for (img,idx) in dataset.imgs]
# with open("img_list_1.pkl","wb") as file:
#     pickle.dump(img_list_1,file)
  • 1
  • 2
  • 3

DataLoader

  1. to loader data from example of torch.utils.data.Dataset
import torch
from torchvision import datasets, transforms

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

# 使用数据加载器迭代样本
for images, labels in train_loader:
    # 训练模型的代码
    ...
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

num_workers

link:加载 in batch的进程数

torchvision.transforms

from torchvision import transforms

# 定义图像预处理操作
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # 缩放图像大小为 (256, 256)
    transforms.RandomCrop((224, 224)),  # 随机裁剪图像为 (224, 224)
    transforms.RandomHorizontalFlip(),  # 随机水平翻转图像
    transforms.ToTensor(),  # 将图像转换为张量
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化图像
])

# 对图像进行预处理
image = transform(image)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

Image play with cv2,PIL.Image

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/345021
推荐阅读
相关标签
  

闽ICP备14008679号