当前位置:   article > 正文

pytorch训练时显存溢出_pytorch训练几次后爆显存

pytorch训练几次后爆显存

网络在前期可以正常训练,但训练几轮后就发生显存爆炸的问题,调整输入大小或者每次循环都清除显存 也无法解决问题,后来经过查询,是在对loss求和时,直接使用

tl += loss

可以看到,loss是张量,经过运算后,tl也是张量,在神经网络中,pytorch会默认将张量操作放到计算图中,随着训练次数的增加,计算图会越来越大,直至显存爆炸。

解决办法:

tl += loss.item()

计算图原理:

        计算图中每个节点代表一个输入,每条边代表一个运算操作。例如y=(a+b)(b+c),则a,b,c都是节点,之后a+b在连接到一个节点,b+c连接到一个节点,最后两节点连接输出。

pytorch是动态建立计算图,边建立边计算。

tensorflow是静态建立计算图

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/article/detail/51860
推荐阅读
相关标签
  

闽ICP备14008679号