当前位置:   article > 正文

pytorch实践小结

pytorch实践小结

根据土堆教程的学习内容,即学即练。要求:收集桃(水蜜桃、油桃和黄桃)的图片数据集,划分训练集和测试集,并完成分类任务。

1. 获取数据(爬虫)

代码来源: http://t.csdnimg.cn/CcEW6

  1. # -*- coding: utf-8 -*-
  2. """
  3. Created on Wed Mar 29 10:17:50 2023
  4. @author: MatpyMaster
  5. """
  6. import requests
  7. import os
  8. import re
  9. def get_images_from_baidu(keyword, page_num, save_dir):
  10. header = {
  11. 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/78.0.3904.108 Safari/537.36'}
  12. # 请求的 url
  13. url = 'https://image.baidu.com/search/acjson?'
  14. n = 0
  15. for pn in range(0, 30 * page_num, 30):
  16. # 请求参数
  17. param = {'tn': 'resultjson_com',
  18. 'logid': '7603311155072595725',
  19. 'ipn': 'rj',
  20. 'ct': 201326592,
  21. 'is': '',
  22. 'fp': 'result',
  23. 'queryWord': keyword,
  24. 'cl': 2,
  25. 'lm': -1,
  26. 'ie': 'utf-8',
  27. 'oe': 'utf-8',
  28. 'adpicid': '',
  29. 'st': -1,
  30. 'z': '',
  31. 'ic': '',
  32. 'hd': '',
  33. 'latest': '',
  34. 'copyright': '',
  35. 'word': keyword,
  36. 's': '',
  37. 'se': '',
  38. 'tab': '',
  39. 'width': '',
  40. 'height': '',
  41. 'face': 0,
  42. 'istype': 2,
  43. 'qc': '',
  44. 'nc': '1',
  45. 'fr': '',
  46. 'expermode': '',
  47. 'force': '',
  48. 'cg': '', # 这个参数没公开,但是不可少
  49. 'pn': pn, # 显示:30-60-90
  50. 'rn': '30', # 每页显示 30 条
  51. 'gsm': '1e',
  52. '1618827096642': ''
  53. }
  54. request = requests.get(url=url, headers=header, params=param)
  55. if request.status_code == 200:
  56. print('Request success.')
  57. request.encoding = 'utf-8'
  58. # 正则方式提取图片链接
  59. html = request.text
  60. image_url_list = re.findall('"thumbURL":"(.*?)",', html, re.S)
  61. if not os.path.exists(save_dir):
  62. os.makedirs(save_dir)
  63. for image_url in image_url_list:
  64. image_data = requests.get(url=image_url, headers=header).content
  65. with open(os.path.join(save_dir, f'{n:06d}.jpg'), 'wb') as fp:
  66. fp.write(image_data)
  67. n = n + 1
  68. if __name__ == "__main__":
  69. keyword = '黄桃'
  70. page_num = 1
  71. page_num = int(page_num)
  72. save_dir = '.\\图片\\' + keyword
  73. get_images_from_baidu(keyword, page_num, save_dir)

2. train.py

  1. import os
  2. import time
  3. from torch.utils.tensorboard import SummaryWriter
  4. import torch.optim
  5. from torch.utils.data import Dataset, DataLoader
  6. from PIL import Image
  7. import torchvision
  8. from torch import nn
  9. #指定设备
  10. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  11. #获取数据
  12. class PeachDate(Dataset):
  13. def __init__(self,root_dir,label_dir):
  14. self.root_dir = root_dir
  15. self.label_dir = label_dir
  16. self.path = os.path.join(root_dir,label_dir) #完整路径
  17. self.img_list = os.listdir(self.path) #图片列表
  18. self.label_mapping = {'honey_peach': 0, 'Huangtao': 1, 'nectarine': 2}
  19. def __getitem__(self, idx):
  20. img_name = self.img_list[idx]
  21. img_path = os.path.join(self.root_dir,self.label_dir,img_name) #单张图片完整路径
  22. img_read = Image.open(img_path) #读取图片
  23. trans_resize = torchvision.transforms.Resize((32,32)) #统一尺寸
  24. trans_tensor = torchvision.transforms.ToTensor() #转换图片类型
  25. if img_read.mode != 'RGB': # 确保通道数为3(RGB)
  26. img_read = img_read.convert('RGB')
  27. img_resize = trans_resize(img_read)
  28. img = trans_tensor(img_resize)
  29. label_str = self.label_dir
  30. label_int = self.label_mapping[label_str]
  31. return img, label_int
  32. def __len__(self):
  33. return len(self.img_list)
  34. train_root_dir = "dataset/train"
  35. test_root_dir = "dataset/test"
  36. train_honeypeach = PeachDate(train_root_dir,"honey_peach")
  37. train_Huangtao = PeachDate(train_root_dir,"Huangtao")
  38. train_nectarine = PeachDate(train_root_dir,"nectarine")
  39. train_data = train_honeypeach + train_Huangtao + train_nectarine
  40. test_honeypeach = PeachDate(test_root_dir,"honey_peach")
  41. test_Huangtao = PeachDate(test_root_dir,"Huangtao")
  42. test_nectarine = PeachDate(test_root_dir,"nectarine")
  43. test_data = test_honeypeach + test_Huangtao + test_nectarine
  44. #数据长度
  45. train_data_size = len(train_data)
  46. test_data_size = len(test_data)
  47. print("训练集的长度为:{}".format(train_data_size),"测试集的长度为:{}".format(test_data_size))
  48. # 打包数据
  49. train_dataloader = DataLoader(train_data,8,shuffle=True)
  50. test_dataloader = DataLoader(test_data,4,shuffle=True)
  51. #创建网络模型 (CIF改写)
  52. class PeachNn(nn.Module):
  53. def __init__(self):
  54. super(PeachNn, self).__init__()
  55. self.model = nn.Sequential(
  56. nn.Conv2d(3,32,5,1,padding=2),
  57. nn.MaxPool2d(2),
  58. nn.Conv2d(32,32,5,1,padding=2),
  59. nn.MaxPool2d(2),
  60. nn.Conv2d(32,64,5,1,padding=2),
  61. nn.MaxPool2d(2),
  62. nn.Flatten(),
  63. nn.Linear(64*4*4,64),
  64. nn.Linear(64,10),
  65. nn.Linear(10,3)
  66. )
  67. def forward(self,x):
  68. x = self.model(x)
  69. return x
  70. peachnn = PeachNn()
  71. peachnn.to(device)
  72. #损失函数
  73. loss_fn = nn.CrossEntropyLoss()
  74. loss_fn.to(device)
  75. #优化器
  76. learing_rate = 0.01
  77. optimzer = torch.optim.SGD(peachnn.parameters(),lr=learing_rate)
  78. #其他参数
  79. total_train_step = 0
  80. total_test_step = 0
  81. epoch = 10
  82. #添加tensorboard
  83. writer = SummaryWriter("./logs")
  84. #开始计时
  85. start_time = time.time()
  86. for i in range(epoch):
  87. print("-------第{}轮训练开始".format(i+1))
  88. #训练
  89. peachnn.train()
  90. for data in train_dataloader:
  91. imgs,targets = data
  92. imgs = imgs.to(device)
  93. targets = targets.to(device)
  94. outputs = peachnn(imgs)
  95. loss = loss_fn(outputs,targets) #损失
  96. #优化器调优
  97. optimzer.zero_grad()
  98. loss.backward()
  99. optimzer.step()
  100. total_train_step = total_train_step+1
  101. if total_train_step % 10 == 0: #每训练十次打印一次
  102. end_time = time.time()
  103. print("训练时长为:{}".format(end_time - start_time))
  104. print('训练第{}次,loss:{}'.format(total_train_step, loss.item()))
  105. writer.add_scalar("train_loss", loss, total_train_step) # 将损失写入tensorboard
  106. peachnn.eval() #测试开始
  107. total_test_loss = 0
  108. total_accuracy = 0
  109. with torch.no_grad():
  110. for data in test_dataloader:
  111. imgs,targets = data
  112. imgs = imgs.to(device)
  113. targets = targets.to(device)
  114. outputs = peachnn(imgs)
  115. loss = loss_fn(outputs,targets)
  116. accuracy = (outputs.argmax(1) == targets).sum()
  117. total_test_loss = total_test_loss + loss.item()
  118. total_accuracy = total_accuracy + accuracy
  119. print("整体测试集上的损失为:{}".format(total_test_loss))
  120. print("测试集上的准确率为:{}".format(total_accuracy/test_data_size))
  121. writer.add_scalar("total_test_loss",total_test_loss,total_test_step)
  122. writer.add_scalar("total_accuracy",total_accuracy,total_test_step)
  123. total_test_step = total_test_step+1
  124. #保存模型
  125. torch.save(peachnn,"peachnn_{}.pt".format(i))
  126. print("第{}轮模型已保存".format(i))
  127. writer.close()

3. test.py

  1. import torch
  2. from PIL import Image
  3. import torchvision
  4. from torch import nn
  5. img_path = "dataset/test/Huangtao/000002.jpg"
  6. img_read = Image.open(img_path)
  7. img = img_read.convert("RGB")
  8. transforms = torchvision.transforms.Compose([torchvision.transforms.Resize((32,32)),torchvision.transforms.ToTensor()])
  9. image = transforms(img)
  10. print(image.shape)
  11. #创建网络模型 (CIF改写)
  12. class PeachNn(nn.Module):
  13. def __init__(self):
  14. super(PeachNn, self).__init__()
  15. self.model = nn.Sequential(
  16. nn.Conv2d(3,32,5,1,padding=2),
  17. nn.MaxPool2d(2),
  18. nn.Conv2d(32,32,5,1,padding=2),
  19. nn.MaxPool2d(2),
  20. nn.Conv2d(32,64,5,1,padding=2),
  21. nn.MaxPool2d(2),
  22. nn.Flatten(),
  23. nn.Linear(64*4*4,64),
  24. nn.Linear(64,10),
  25. nn.Linear(10,3)
  26. )
  27. def forward(self,x):
  28. x = self.model(x)
  29. return x
  30. peahcnn = PeachNn()
  31. # model = torch.load("peachnn_9.pt",map_location=torch.device("cuda"))
  32. model = torch.load("peachnn_9.pt")
  33. print(model)
  34. image = torch.reshape(image,(1,3,32,32))
  35. model.eval()
  36. with torch.no_grad():
  37. output = peahcnn(image)
  38. print(output)
  39. print(output.argmax(1))

特别感谢

CSDN 
  土堆老师 & MatpyMaster

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/228789
推荐阅读
相关标签
  

闽ICP备14008679号