当前位置:   article > 正文

gradient_accumulation_steps

gradient_accumulation_steps
  1. num_steps_all = len(train_loader) // configs.gradient_accumulation_steps * configs.epochs
  2. warmup_steps = int(num_steps_all * configs.warmup_proportion)
  3. scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=num_steps_all)

num_steps_all为更新次数,warmup_steps表示全部训练步骤的前10%,在这一阶段,学习率线性增加;此后,学习率线性衰减。

  1. for i,(images,target) in enumerate(train_loader):
  2. # 1. input output
  3. images = images.cuda(non_blocking=True)
  4. target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
  5. outputs = model(images)
  6. loss = criterion(outputs,target)
  7. # 2.1 loss regularization
  8. loss = loss/accumulation_steps
  9. # 2.2 back propagation
  10. loss.backward()
  11. # 3. update parameters of net
  12. if((i+1)%accumulation_steps)==0:
  13. # optimizer the net
  14. optimizer.step() # update parameters of net
  15. optimizer.zero_grad() # reset gradient
  1. 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
  2. loss.backward() 反向传播,计算当前梯度;
  3. 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
  4. 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;

总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。

一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize '变相' 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。

 

gradient_accumulation_steps通过累计梯度来解决本地显存不足问题。
假设原来的batch_size=6,样本总量为24,gradient_accumulation_steps=2
那么参数更新次数=24/6=4
现在,减小batch_size=6/2=3,参数更新次数不变=24/3/2=4
在梯度反传时,每gradient_accumulation_steps次进行一次梯度更新,之前照常利用loss.backward()计算梯度。

 

如果一个模型是需要多卡并行训练以开大batchsize,而你没有这么多卡。那可以利用梯度累加的性质,在每次反向传播后,先不进行优化器的迭代,多累积几个batch的梯度后,再进行优化器迭代、梯度清零的操作。这样的话,即使使用单卡也可以达到多卡开大batch_size的效果哦~虽然训练会慢一点就是了,但是对卡的要求大大降低了。

作者:知乎用户
链接:https://www.zhihu.com/question/303070254/answer/647888393
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

 

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号