赞
踩
本文是Hinton大神在网络压缩与迁移学习方向挖的一个坑
原文链接Distilling the Knowledge in a Neural Network
这种方法感觉受到了ensemble的启发,利用大型(teacher net)网络提取先验知识,将这种先验知识作为soft target让微型网络(student network)学习,有点像Boost中第一个分类器学到后调整weight让第二个分类器学习。当然相似中也有不同之处
主旨就是从复杂网络(teacher net)中抽取训练数据的分布‘教给’简易网络(student net)
- 复杂网络可以从
- ensmeble models 中学习
- 从单独的大型网络(使用正则项或dropout)中学习
复杂网络通常通过最大化log probability去学习多分类,这有个副作用是模型通常会赋予非正确答案一定的概率,即使这些概率都很小但是有一些是明显大于其他的(类似于softmax)。
在transfer的过程中(即cumbersome network转向small network的时候)可以使用同样的训练集也可以使用单独的训练集(猜想:可能使用不同的训练集的时候效果会更好;但是文中说使用原始数据集效果更好)
值得注意的是 对于Mnist数据集中总是产生具有信心的正确结果,很多非正确数字的概率都是非常小的,举个栗子,对于正确的2来说,被分类为3的概率为
10−6 被分类为7的概率为10−9 。在这种情况下soft target的帮助就很小,所以Caruana在其文章中使用log函数解决这一问题。本文利用蒸馏解决这个问题(蒸馏会使得target变得更加soft)
加入蒸馏后的softmax函数:
这里T是超参数,文中说是‘温度’,经过该参数之后的softmax会更加平滑,分布更加均匀而大小关系不变。T参数在设置为1的时候就是平常的softmax函数。
在知识转换阶段,设置复杂网络与简易网络相同的T参数。在此之后再从新将T设置为1
根据这篇博客再加入T之后的softmax的概率分布更加平滑,作为soft target时简易(student)网络能学到更多东西
另外使用matlab做蒸馏
在hinton的slides中也可以看到更加平滑的概率slides
具体蒸馏结构如下图所示:
这里
1、训练大模型:先用hard target,也就是正常的label训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的output。
3、训练小模型,在小模型的基础上再加一个额外的soft target的loss function,通过lambda来调节两个loss functions的比重。
4、预测时,将训练好的小模型按常规方式(右图)使用。
这里参考
将hard target的参数设置的比较小可以得到较好的结果,在做梯度下降时需要将soft target的梯度乘以1/
T2 ,这是因为将梯度的规模保持与hard target一样
首先给出交叉上的求导公式
∂C∂zi=1T(qi−pi)=1T(eziT∑jezjT)−1T(eviT∑jevjT)(2)
其中vi 是cumbersome model产生的logitsvi 相当于真实分布,zi 是distilled model
公式推导:
qi=1T(eziT∑jezjT)(2.1)
pi=1T(eviT∑jevjT)(2.2)
cross entropy=−∑j pj×log qj(2.3)
对交叉熵求导即可得到公式(2)
并且作者给出当T非常大的时候distillation优化的目标等价于Caruana提取的对logits的平方误差求最优化
∂C∂zi≃1T(1+ziTN+∑jzjT)(1+viTN+∑jvjT)(3)
当我们假设logits是zero-means的则∑jzj=∑jvj=0
∂C∂zi≃1NT2(zi−vi)(4)
训练一个有两层具有1200个单元的隐藏层的大型网络(使用dropout和weight-constraints作为正则)值得注意的一点是dropout可以看做是share weights 的ensemble models
另外一个小一点的网络具有两层800个单元隐藏层没有正则
结果是第一个网络test error 67个,第二个是146个;再加入soft target并且T设置为20之后小型网络test error达到74个
另外需要注意一点的是:
When the distilled net had 300 or more units in each of its two hidden layers, all temperatures above 8 gave fairly similar results. But when this was radically reduced to 30 units per layer, temperatures in the range 2.5 to 4 worked significantly better than higher or lower temperatures.
该现象可能说明将概率设置的过于soften可能会导致一些问题尤其是在拟合能力较差的网络中
另外的重要发现
剩下的就是在语音数据及的实验以及大型数据及JFT的实验。具体可以参考原文。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。