当前位置:   article > 正文

人工智能-作业3:例题程序复现 PyTorch版_round(w.grad.item(), 2)

round(w.grad.item(), 2)

一、使用pytorch复现课上例题。

  1. import torch
  2. x1, x2 = torch.Tensor([0.5]), torch.Tensor([0.3])
  3. y1, y2 = torch.Tensor([0.23]), torch.Tensor([-0.07])
  4. print("=====输入值:x1, x2;真实输出值:y1, y2=====")
  5. print(x1, x2, y1, y2)
  6. w1, w2, w3, w4, w5, w6, w7, w8 = torch.Tensor([0.2]), torch.Tensor([-0.4]), torch.Tensor([0.5]), torch.Tensor(
  7. [0.6]), torch.Tensor([0.1]), torch.Tensor([-0.5]), torch.Tensor([-0.3]), torch.Tensor([0.8]) # 权重初始值
  8. w1.requires_grad = True
  9. w2.requires_grad = True
  10. w3.requires_grad = True
  11. w4.requires_grad = True
  12. w5.requires_grad = True
  13. w6.requires_grad = True
  14. w7.requires_grad = True
  15. w8.requires_grad = True
  16. def sigmoid(z):
  17. a = 1 / (1 + torch.exp(-z))
  18. return a
  19. def forward_propagate(x1, x2):
  20. in_h1 = w1 * x1 + w3 * x2
  21. out_h1 = sigmoid(in_h1) # out_h1 = torch.sigmoid(in_h1)
  22. in_h2 = w2 * x1 + w4 * x2
  23. out_h2 = sigmoid(in_h2) # out_h2 = torch.sigmoid(in_h2)
  24. in_o1 = w5 * out_h1 + w7 * out_h2
  25. out_o1 = sigmoid(in_o1) # out_o1 = torch.sigmoid(in_o1)
  26. in_o2 = w6 * out_h1 + w8 * out_h2
  27. out_o2 = sigmoid(in_o2) # out_o2 = torch.sigmoid(in_o2)
  28. print("正向计算:o1 ,o2")
  29. print(out_o1.data, out_o2.data)
  30. return out_o1, out_o2
  31. def loss_fuction(x1, x2, y1, y2): # 损失函数
  32. y1_pred, y2_pred = forward_propagate(x1, x2) # 前向传播
  33. loss = (1 / 2) * (y1_pred - y1) ** 2 + (1 / 2) * (y2_pred - y2) ** 2 # 考虑 : t.nn.MSELoss()
  34. print("损失函数(均方误差):", loss.item())
  35. return loss
  36. def update_w(w1, w2, w3, w4, w5, w6, w7, w8):
  37. # 步长
  38. step = 1
  39. w1.data = w1.data - step * w1.grad.data
  40. w2.data = w2.data - step * w2.grad.data
  41. w3.data = w3.data - step * w3.grad.data
  42. w4.data = w4.data - step * w4.grad.data
  43. w5.data = w5.data - step * w5.grad.data
  44. w6.data = w6.data - step * w6.grad.data
  45. w7.data = w7.data - step * w7.grad.data
  46. w8.data = w8.data - step * w8.grad.data
  47. w1.grad.data.zero_() # 注意:将w中所有梯度清零
  48. w2.grad.data.zero_()
  49. w3.grad.data.zero_()
  50. w4.grad.data.zero_()
  51. w5.grad.data.zero_()
  52. w6.grad.data.zero_()
  53. w7.grad.data.zero_()
  54. w8.grad.data.zero_()
  55. return w1, w2, w3, w4, w5, w6, w7, w8
  56. if __name__ == "__main__":
  57. print("=====更新前的权值=====")
  58. print(w1.data, w2.data, w3.data, w4.data, w5.data, w6.data, w7.data, w8.data)
  59. for i in range(1):
  60. print("=====第" + str(i) + "轮=====")
  61. L = loss_fuction(x1, x2, y1, y2) # 前向传播,求 Loss,构建计算图
  62. L.backward() # 自动求梯度,不需要人工编程实现。反向传播,求出计算图中所有梯度存入w中
  63. print("\tgrad W: ", round(w1.grad.item(), 2), round(w2.grad.item(), 2), round(w3.grad.item(), 2),
  64. round(w4.grad.item(), 2), round(w5.grad.item(), 2), round(w6.grad.item(), 2), round(w7.grad.item(), 2),
  65. round(w8.grad.item(), 2))
  66. w1, w2, w3, w4, w5, w6, w7, w8 = update_w(w
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/807862
推荐阅读
相关标签
  

闽ICP备14008679号