赞
踩
- import torch
- import torch.utils.data as Data
- import torch.nn.functional as f
- import matplotlib.pyplot as plt
-
- #指定超参数
- LR=0.01#学习率
- BATCH_SIZE=32#批数据的大小
- EPOCH=12#迭代次数
-
- #构造数据集
- x=torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
- y=x.pow(2)+0.1*(torch.normal(torch.zeros(*x.size())))
-
-
- #打印数据
- plt.scatter(x.data.numpy(),y.data.numpy(),c='r')
- plt.show()
-
- #使用dataloader工具进行数据的处理
- torch_dataset=Data.TensorDataset(x,y)#将x和y转换成torch可识别的数据集
- loader=Data.DataLoader(
- dataset=torch_dataset,
- batch_size=BATCH_SIZE,
- shuffle=True,
- num_workers=2
- )
-
- #构造网络结构并为每一个优化器优化一个神经网络
- class Net(torch.nn.Module):
- def __init__(self):
- super(Net,self).__init__()
- self.hidden=torch.nn.Linear(1,20)
- self.output=torch.nn.Linear(20,1)
- #前向传播
- def forward(self,x):
- x=f.relu(self.hidden(x))
- x=self.output(x)
- return x
-
- #每一个优化器对应一个网络结构
- net_SGD=Net()
- net_Momentum=Net()
- net_RMSprop=Net()
- net_Adam=Net()
-
- #放到一个列表中
- nets=[net_SGD,net_Momentum,net_RMSprop,net_Adam]
-
-
- #API化每一个优化器
- opt_SGD=torch.optim.SGD(net_SGD.parameters(),lr=LR)
- opt_Momentum=torch.optim.SGD(net_Momentum.parameters(),lr=LR,momentum=0.8)
- opt_RMSprop=torch.optim.RMSprop(net_RMSprop.parameters(),lr=LR,alpha=0.9)
- opt_Adam=torch.optim.Adam(net_Adam.parameters(),lr=LR,betas=(0.9,0.99))
-
- #用一个列表存放每一个优化器
- optimizers=[opt_SGD,opt_Momentum,opt_RMSprop,opt_Adam]
- #指定损失函数
- loss_func=torch.nn.MSELoss()
- #用一个两层列表记录各个优化器的loss
- loss_his=[[],[],[],[]]
-
- #训练 可视化
- for epoch in range(EPOCH):
- print(epoch)
- for step,(batch_x,batch_y) in enumerate(loader):
-
- #对于每一个优化器,优化他的神经网络
- for net,opt,l_his in zip(nets,optimizers,loss_his):
- output=net(batch_x)#对每一个网络丢入数据
- loss=loss_func(output,batch_y)#计算预测值和真实值之间的误差
- opt.zero_grad()#梯度清零
- loss.backward()#反向传播
- opt.step()#更新每一个参数
- l_his.append(loss.data.numpy())
-
- #可视化
- lables=["SGD","Momentum","RMSprop","Adam"]
- for i,l_his in enumerate(loss_his):#enumerate是列举,会迭代列表的中的每一个索引和每一项的值
- plt.plot(l_his,label=lables[i])
- plt.legend(loc=1)#legend是做一个图例说明 loc=1表示放在右边 详情看参数,label=lables[i]相对应
- plt.xlabel("steps")
- plt.ylabel("loss")
- plt.ylim((0,0.5))
- plt.show()

Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。