当前位置:   article > 正文

pytorch系列(六):各种优化器的性能比较_pytorch优化器效果对比

pytorch优化器效果对比
  1. import torch
  2. import torch.utils.data as Data
  3. import torch.nn.functional as f
  4. import matplotlib.pyplot as plt
  5. #指定超参数
  6. LR=0.01#学习率
  7. BATCH_SIZE=32#批数据的大小
  8. EPOCH=12#迭代次数
  9. #构造数据集
  10. x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
  11. y=x.pow(2)+0.1*(torch.normal(torch.zeros(*x.size())))
  12. #打印数据
  13. plt.scatter(x.data.numpy(),y.data.numpy(),c='r')
  14. plt.show()
  15. #使用dataloader工具进行数据的处理
  16. torch_dataset=Data.TensorDataset(x,y)#将x和y转换成torch可识别的数据集
  17. loader=Data.DataLoader(
  18. dataset=torch_dataset,
  19. batch_size=BATCH_SIZE,
  20. shuffle=True,
  21. num_workers=2
  22. )
  23. #构造网络结构并为每一个优化器优化一个神经网络
  24. class Net(torch.nn.Module):
  25. def __init__(self):
  26. super(Net,self).__init__()
  27. self.hidden=torch.nn.Linear(1,20)
  28. self.output=torch.nn.Linear(20,1)
  29. #前向传播
  30. def forward(self,x):
  31. x=f.relu(self.hidden(x))
  32. x=self.output(x)
  33. return x
  34. #每一个优化器对应一个网络结构
  35. net_SGD=Net()
  36. net_Momentum=Net()
  37. net_RMSprop=Net()
  38. net_Adam=Net()
  39. #放到一个列表中
  40. nets=[net_SGD,net_Momentum,net_RMSprop,net_Adam]
  41. #API化每一个优化器
  42. opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
  43. opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.8)
  44. opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
  45. opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))
  46. #用一个列表存放每一个优化器
  47. optimizers=[opt_SGD,opt_Momentum,opt_RMSprop,opt_Adam]
  48. #指定损失函数
  49. loss_func=torch.nn.MSELoss()
  50. #用一个两层列表记录各个优化器的loss
  51. loss_his=[[],[],[],[]]
  52. #训练 可视化
  53. for epoch in range(EPOCH):
  54. print(epoch)
  55. for step,(batch_x,batch_y) in enumerate(loader):
  56. #对于每一个优化器,优化他的神经网络
  57. for net,opt,l_his in zip(nets,optimizers,loss_his):
  58. output=net(batch_x)#对每一个网络丢入数据
  59. loss=loss_func(output,batch_y)#计算预测值和真实值之间的误差
  60. opt.zero_grad()#梯度清零
  61. loss.backward()#反向传播
  62. opt.step()#更新每一个参数
  63. l_his.append(loss.data.numpy())
  64. #可视化
  65. lables=["SGD","Momentum","RMSprop","Adam"]
  66. for i,l_his in enumerate(loss_his):#enumerate是列举,会迭代列表的中的每一个索引和每一项的值
  67. plt.plot(l_his,label=lables[i])
  68. plt.legend(loc=1)#legend是做一个图例说明 loc=1表示放在右边 详情看参数,label=lables[i]相对应
  69. plt.xlabel("steps")
  70. plt.ylabel("loss")
  71. plt.ylim((0,0.5))
  72. plt.show()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/921840
推荐阅读
相关标签
  

闽ICP备14008679号