赞
踩
import torch from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data): self.data = data def __getitem__(self, index): # 根据索引获取样本 return self.data[index] def __len__(self): # 返回数据集大小 return len(self.data) # 创建数据集对象 data = [1, 2, 3, 4, 5] dataset = MyDataset(data) # 根据索引获取样本 sample = dataset[2] print(sample)
import torch
from torchvision import datasets, transforms
# 定义数据转换
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5,), (0.5,)) # 标准化图像
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
from torchvision import datasets,transforms from torch.utils.data import DataLoader # transform.Compose是PyTorch中的一个类,用于将多个图像变换操作组合在一起。它的作用是将这些操作按照顺序依次应用于输入的图像数据。 trans = transforms.Compose([ np.float32, transforms.ToTensor(), fixed_image_standardization ]) dataset = datasets.ImageFolder(data_dir, transform=trans) loader = DataLoader( dataset, num_workers=workers, batch_size=batch_size, collate_fn=training.collate_pil )
# 定义输入图像的数据加载器
mytransform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
])
dataset = datasets.ImageFolder(data_dir, transform=mytransform) #对于图像,必须transform转成Tensor,才能for input,label in train_loader读取
print(dataset)
print(len(dataset))
print(len(dataset.imgs))
print(len(dataset.classes))
print(dataset.classes[-1])
print(dataset.classes)
print(dataset.imgs)
\root
\cls1
\img1.png
\img2.png
\cls2
\img1.png
\img2.png
\cls3
\img1.png
\img2.png
# img_list_1=[img for (img,idx) in dataset.imgs]
# with open("img_list_1.pkl","wb") as file:
# pickle.dump(img_list_1,file)
import torch
from torchvision import datasets, transforms
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)
# 使用数据加载器迭代样本
for images, labels in train_loader:
# 训练模型的代码
...
link:加载 in batch的进程数
from torchvision import transforms
# 定义图像预处理操作
transform = transforms.Compose([
transforms.Resize((256, 256)), # 缩放图像大小为 (256, 256)
transforms.RandomCrop((224, 224)), # 随机裁剪图像为 (224, 224)
transforms.RandomHorizontalFlip(), # 随机水平翻转图像
transforms.ToTensor(), # 将图像转换为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化图像
])
# 对图像进行预处理
image = transform(image)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。