当前位置:   article > 正文

pytorch 笔记:叶子张量,retain_grad,register_hook_torch retain_grad

torch retain_grad

1 概念介绍

        在pytorch的tensor类中,有个is_leaf的属性,表示这个tensor是否是叶子节点:is_leaf 为False的时候,则不是叶子节点, is_leaf为True的时候为叶子节点(有的地方也叫做叶子张量)

1.1 为什么要叶子节点?

        对于tensor中的 requires_grad()属性,当requires_grad()为True时我们将会记录tensor的运算过程并为自动求导做准备。

        但是并不是每个requires_grad()设为True的值都会在backward的时候得到相应的grad。它还必须为leaf。

        这就说明. leaf成为了在 requires_grad()下判断是否需要保留 grad的前提条件

  • 提出叶子张量的原因是为了节省内存/显存
    • 那些非叶子结点是通过用户所定义的叶子节点的一系列运算生成的(也即中间变量)
    • 一般情况下,用户不会去使用这些中间变量的导数,所以为了节省内存,它们在用完之后就被释放了

1.2  哪些张量是叶子张量?

  • 所有requires_grad为False的张量(Tensor) 都为叶张量( leaf Tensor)
  1. x_=torch.arange(10,dtype=torch.float32).reshape(10,1)
  2. x_.is_leaf
  3. #True

  • requires_grad为True的张量(Tensor),如果他们是由用户创建的,则它们是叶张量(leaf Tensor).这意味着它们不是运算的结果,因此gra_fn为None
  1. xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1)
  2. xx.is_leaf
  3. #False
  1. xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1)
  2. ww=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(1,10)
  3. yy=ww@xx
  4. yy.backward()
  5. xx.grad,ww.grad
  6. #(None, None)
  7. '''
  8. UserWarning: The .grad attribute of a Tensor that is not a leaf
  9. Tensor is being accessed. Its .grad attribute won't be populated
  10. during autograd.backward(). If you indeed want the gradient for
  11. a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor.
  12. If you access the non-leaf Tensor by mistake, make sure you
  13. access the leaf Tensor instead.
  14. See github.com/pytorch/pytorch/pull/30531
  15. for more informations.
  16. '''

  • 只有是叶张量的tensor在反向传播时才会将本身的grad传入的backward的运算中.。如果想得到当前自己创建的,requires_grad为True的tensor在反向传播时的grad, 可以用retain_grad()这个属性
  1. xx=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(10,1)
  2. ww=torch.arange(10,dtype=torch.float32,requires_grad=True).reshape(1,10)
  3. yy=ww@xx
  4. xx.retain_grad()
  5. ww.retain_grad()
  6. yy.backward()
  7. xx.grad,ww.grad
  8. '''
  9. (tensor([[0.],
  10. [1.],
  11. [2.],
  12. [3.],
  13. [4.],
  14. [5.],
  15. [6.],
  16. [7.],
  17. [8.],
  18. [9.]]),
  19. tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]))
  20. '''

 2 保存中间变量的梯度

  • 如果需要保留中间变量的导数,那么可以使用tensor.retain_grad()
    • 哪一个张量需要保存,哪一个张量加上retain_grad()
  1. loss = l4.mean()
  2. l4.retain_grad()
  3. loss.backward()
  4. print(l4.grad)

3 输出中间变量的梯度

 如果我们只是想进行 debug,只需要输出中间变量的导数信息,而不需要保存它们,我们还可以使用 tensor.register_hook

  1. loss = l4.mean()
  2. l4.register_hook(lambda grad: print('l4 grad:', grad))
  3. loss.backward()

参考内容:【one way的pytorch学习笔记】(三)leaf 叶子(张量)_One Way的博客-CSDN博客

PyTorch 之Autograd 详解 (qq.com) 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/340928
推荐阅读
相关标签
  

闽ICP备14008679号