赞
踩
根据土堆教程的学习内容即学即练。要求:收集桃(水蜜桃、油桃和黄桃)的图片数据集,划分训练集和测试集,并完成分类任务。
1. 获取数据(爬虫)
代码来源: http://t.csdnimg.cn/CcEW6
- # -*- coding: utf-8 -*-
- """
- Created on Wed Mar 29 10:17:50 2023
- @author: MatpyMaster
- """
- import requests
- import os
- import re
-
-
def get_images_from_baidu(keyword, page_num, save_dir, timeout=10):
    """Download thumbnail images for *keyword* from Baidu image search.

    For each of *page_num* result pages (30 results per page) the ``acjson``
    endpoint is queried, every ``thumbURL`` link in the response is fetched,
    and the bytes are written to *save_dir* as ``000000.jpg``,
    ``000001.jpg``, ... Numbering restarts at 0 on every call, so re-running
    overwrites earlier files.

    Failures are best-effort. A page or image whose request fails or returns
    an HTTP error is reported and skipped, so one bad link does not abort the
    crawl. Skipping also stops error pages from being saved as ``.jpg`` files
    that PIL later cannot open.

    Args:
        keyword: search term, e.g. '黄桃'.
        page_num: number of 30-result pages to request.
        save_dir: output directory; created if it does not exist.
        timeout: seconds to wait for each HTTP request (optional; without a
            timeout a stalled connection would block forever).

    Returns:
        The number of images written.
    """
    header = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}
    # Search endpoint
    url = 'https://image.baidu.com/search/acjson?'
    # Create the output directory once, up front, instead of once per page.
    os.makedirs(save_dir, exist_ok=True)
    n = 0
    for pn in range(0, 30 * page_num, 30):
        # Query parameters
        param = {'tn': 'resultjson_com',
                 'logid': '7603311155072595725',
                 'ipn': 'rj',
                 'ct': 201326592,
                 'is': '',
                 'fp': 'result',
                 'queryWord': keyword,
                 'cl': 2,
                 'lm': -1,
                 'ie': 'utf-8',
                 'oe': 'utf-8',
                 'adpicid': '',
                 'st': -1,
                 'z': '',
                 'ic': '',
                 'hd': '',
                 'latest': '',
                 'copyright': '',
                 'word': keyword,
                 's': '',
                 'se': '',
                 'tab': '',
                 'width': '',
                 'height': '',
                 'face': 0,
                 'istype': 2,
                 'qc': '',
                 'nc': '1',
                 'fr': '',
                 'expermode': '',
                 'force': '',
                 'cg': '',  # undocumented, but the request fails without it
                 'pn': pn,  # result offset: 0, 30, 60, ...
                 'rn': '30',  # 30 results per page
                 'gsm': '1e',
                 '1618827096642': ''
                 }
        try:
            response = requests.get(url=url, headers=header, params=param, timeout=timeout)
        except requests.RequestException as err:
            print(f'Request failed for page offset {pn}: {err}')
            continue
        if response.status_code != 200:
            print(f'Request failed with status {response.status_code} for page offset {pn}.')
            continue
        print('Request success.')
        response.encoding = 'utf-8'
        # Extract the image links with a regex
        image_url_list = re.findall('"thumbURL":"(.*?)",', response.text, re.S)

        for image_url in image_url_list:
            try:
                image_response = requests.get(url=image_url, headers=header, timeout=timeout)
                image_response.raise_for_status()
            except requests.RequestException as err:
                print(f'Skipping {image_url}: {err}')
                continue
            with open(os.path.join(save_dir, f'{n:06d}.jpg'), 'wb') as fp:
                fp.write(image_response.content)
            n += 1
    return n
-
-
if __name__ == "__main__":
    keyword = '黄桃'
    page_num = 1
    # os.path.join builds a path that is valid on every OS. A hard-coded
    # '.\\图片\\' prefix only works on Windows; on Linux/macOS it creates a
    # single directory literally named '.\图片\黄桃'.
    save_dir = os.path.join('.', '图片', keyword)
    get_images_from_baidu(keyword, page_num, save_dir)

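2. train.py(训练前需将爬取的图片手动整理为 dataset/train/<类别> 和 dataset/test/<类别> 的目录结构,类别目录名分别为 honey_peach、Huangtao、nectarine)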
- import os
- import time
- from torch.utils.tensorboard import SummaryWriter
- import torch.optim
- from torch.utils.data import Dataset, DataLoader
- from PIL import Image
- import torchvision
- from torch import nn
-
# Compute device: CUDA when a GPU is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
# Dataset
class PeachDate(Dataset):
    """Image-classification dataset for a single peach class.

    Every image file in ``os.path.join(root_dir, label_dir)`` is one sample.
    Its label is the integer ``label_mapping`` assigns to ``label_dir``.
    Instances for different classes are combined with ``+`` to form the full
    train/test set.

    Each sample is returned as ``(tensor, label)``. The tensor is the image
    converted to RGB, resized to 32x32 and passed through ``ToTensor``.

    Raises:
        KeyError: if ``label_dir`` is not one of the known class folders.
    """

    # Extensions treated as images. Other files a folder may contain
    # (Thumbs.db, .DS_Store, ...) would make PIL fail in __getitem__.
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)  # folder of this class
        # Sorted so sample indices are reproducible (os.listdir order is arbitrary).
        self.img_list = sorted(
            name for name in os.listdir(self.path)
            if name.lower().endswith(self.IMG_EXTENSIONS)
        )
        self.label_mapping = {'honey_peach': 0, 'Huangtao': 1, 'nectarine': 2}
        # Fail at construction time rather than on the first __getitem__.
        if label_dir not in self.label_mapping:
            raise KeyError(
                f"unknown class folder {label_dir!r}; "
                f"expected one of {sorted(self.label_mapping)}"
            )
        self.label = self.label_mapping[label_dir]
        # Build the transforms once instead of on every sample.
        self.trans_resize = torchvision.transforms.Resize((32, 32))  # uniform size
        self.trans_tensor = torchvision.transforms.ToTensor()  # PIL image -> tensor

    def __getitem__(self, idx):
        img_path = os.path.join(self.path, self.img_list[idx])
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle. convert('RGB') returns a loaded copy, which
        # stays valid after closing, and guarantees the 3 channels the first
        # Conv2d expects.
        with Image.open(img_path) as img_read:
            img = img_read.convert('RGB')
        img = self.trans_tensor(self.trans_resize(img))
        return img, self.label

    def __len__(self):
        return len(self.img_list)
-
train_root_dir = "dataset/train"
test_root_dir = "dataset/test"

# One PeachDate per class folder; '+' concatenates them into one dataset.
train_honeypeach = PeachDate(train_root_dir, "honey_peach")
train_Huangtao = PeachDate(train_root_dir, "Huangtao")
train_nectarine = PeachDate(train_root_dir, "nectarine")
train_data = train_honeypeach + train_Huangtao + train_nectarine

test_honeypeach = PeachDate(test_root_dir, "honey_peach")
test_Huangtao = PeachDate(test_root_dir, "Huangtao")
test_nectarine = PeachDate(test_root_dir, "nectarine")
test_data = test_honeypeach + test_Huangtao + test_nectarine

# Dataset sizes
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练集的长度为:{}".format(train_data_size), "测试集的长度为:{}".format(test_data_size))

# Fail early with a clear message. An empty train set otherwise raises an
# opaque sampler error, and an empty test set raises ZeroDivisionError in the
# accuracy computation only after a full training epoch.
if train_data_size == 0 or test_data_size == 0:
    raise RuntimeError(
        f"no images found: train={train_data_size} under {train_root_dir!r}, "
        f"test={test_data_size} under {test_root_dir!r}"
    )

# Only the training set needs shuffling. Evaluation sums over every test
# sample, so a fixed order loses nothing and avoids needless randomness.
train_dataloader = DataLoader(train_data, 8, shuffle=True)
test_dataloader = DataLoader(test_data, 4, shuffle=False)
-
-
# Network model (adapted from the tutorial's CIFAR-10 model)
class PeachNn(nn.Module):
    """Small CNN mapping a batch of 3x32x32 images to 3 class logits.

    Three stages of 5x5 convolution (padding 2, spatial size kept) followed by
    2x2 max-pooling shrink 32 -> 16 -> 8 -> 4. The flattened feature vector
    therefore has 64 * 4 * 4 entries before the linear head.
    NOTE: no activation function sits between the layers.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=64 * 4 * 4, out_features=64),
            nn.Linear(in_features=64, out_features=10),
            nn.Linear(in_features=10, out_features=3),
        ]
        # Attribute name and layer order determine the state_dict keys.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
-
peachnn = PeachNn()
peachnn.to(device)

# Loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

# Optimizer
learing_rate = 0.01
optimzer = torch.optim.SGD(peachnn.parameters(), lr=learing_rate)

# Counters and hyper-parameters
total_train_step = 0
total_test_step = 0
epoch = 10

# TensorBoard logging
writer = SummaryWriter("./logs")

# Start the wall-clock timer
start_time = time.time()

try:
    for i in range(epoch):
        print("-------第{}轮训练开始".format(i + 1))

        # Training
        peachnn.train()
        for imgs, targets in train_dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = peachnn(imgs)
            loss = loss_fn(outputs, targets)

            # Optimization step
            optimzer.zero_grad()
            loss.backward()
            optimzer.step()

            total_train_step += 1

            if total_train_step % 10 == 0:  # report every 10 steps
                end_time = time.time()
                # Elapsed time since training started (cumulative).
                print("训练时长为:{}".format(end_time - start_time))
                print('训练第{}次,loss:{}'.format(total_train_step, loss.item()))
                writer.add_scalar("train_loss", loss.item(), total_train_step)

        # Evaluation
        peachnn.eval()
        total_test_loss = 0.0  # sum of per-batch mean losses
        total_accuracy = 0  # number of correctly classified test samples
        with torch.no_grad():
            for imgs, targets in test_dataloader:
                imgs = imgs.to(device)
                targets = targets.to(device)
                outputs = peachnn(imgs)
                loss = loss_fn(outputs, targets)
                total_test_loss += loss.item()
                # .item() keeps the count as a Python int rather than a device tensor.
                total_accuracy += (outputs.argmax(1) == targets).sum().item()

        test_accuracy = total_accuracy / test_data_size
        print("整体测试集上的损失为:{}".format(total_test_loss))
        print("测试集上的准确率为:{}".format(test_accuracy))

        writer.add_scalar("total_test_loss", total_test_loss, total_test_step)
        # Log the accuracy ratio (the value printed above), not the raw count.
        writer.add_scalar("total_accuracy", test_accuracy, total_test_step)

        total_test_step += 1

        # Save the whole pickled module (test.py loads it with torch.load).
        torch.save(peachnn, "peachnn_{}.pt".format(i))
        print("第{}轮模型已保存".format(i))
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

3. test.py
- import torch
- from PIL import Image
- import torchvision
- from torch import nn
-
img_path = "dataset/test/Huangtao/000002.jpg"
# Image.open is lazy and keeps the file open; the context manager releases it.
# convert("RGB") returns a loaded copy (usable after closing) with the 3
# channels the model expects.
with Image.open(img_path) as img_read:
    img = img_read.convert("RGB")

# Same preprocessing as training: resize to 32x32, then convert to a tensor.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transforms(img)

print(image.shape)
-
# Network model (adapted from the tutorial's CIFAR-10 model)
class PeachNn(nn.Module):
    """Small CNN mapping a batch of 3x32x32 images to 3 class logits.

    Three stages of 5x5 convolution (padding 2, spatial size kept) followed by
    2x2 max-pooling shrink 32 -> 16 -> 8 -> 4. The flattened feature vector
    therefore has 64 * 4 * 4 entries before the linear head.
    NOTE: no activation function sits between the layers.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),
            nn.Linear(in_features=64 * 4 * 4, out_features=64),
            nn.Linear(in_features=64, out_features=10),
            nn.Linear(in_features=10, out_features=3),
        ]
        # Attribute name and layer order must match the training-side class
        # so that saved modules and state_dicts stay compatible.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
-
# Run on the GPU when available. map_location lets a checkpoint saved on a GPU
# machine load on a CPU-only machine (and vice versa).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# train.py saves the whole pickled module (torch.save(peachnn, ...)), so the
# PeachNn class must be defined above for unpickling to succeed.
# weights_only=False is needed for full-module checkpoints on PyTorch >= 2.6,
# where the default became True. It unpickles arbitrary objects, so only load
# checkpoint files you trust.
model = torch.load("peachnn_9.pt", map_location=device, weights_only=False)
print(model)

# Add the batch dimension (1, 3, 32, 32) and move to the model's device.
image = torch.reshape(image, (1, 3, 32, 32)).to(device)
model.eval()
with torch.no_grad():
    # Use the *loaded* model. A freshly constructed PeachNn() has random,
    # untrained weights, so its prediction would be meaningless.
    output = model(image)

print(output)
print(output.argmax(1))

特别感谢
CSDN
土堆老师 & MatpyMaster
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。