赞
踩
注:文章来自于我的博客shawnluo.com,欢迎访问~!
作者:Geoffrey Hinton,Oriol Vinyals,Jeff Dean
发表信息:Machine Learning (cs.LG); Neural and Evolutionary Computing (cs.NE)
神经网络训练阶段从大量数据中获取网络模型,训练阶段可以利用大量的计算资源且不需要实时响应。然而到达使用阶段,神经网络需要面临更加严格的要求包括计算资源限制,计算速度要求等等。
一个复杂的网络结构模型是若干个单独模型组成的集合,或者是一些很强的约束条件下(比如dropout率很高)训练得到的一个很大的网络模型。一旦复杂网络模型训练完成,我们便可以用另一种训练方法:“蒸馏”,把我们需要配置在应用端的缩小模型从复杂模型中提取出来。“蒸馏”的难点在于如何缩减网络结构但是把网络中的知识保留下来。知识就是一幅将输入向量导引至输出向量的地图。做复杂网络的训练时,目标是将正确答案的概率最大化,但这引入了一个副作用:这种网络为所有错误答案分配了概率,即使这些概率非常小。我们将复杂模型转化为小模型时需要注意保留模型的泛化能力,一种方法是利用由复杂模型产生的分类概率作为“软目标”来训练小模型。在转化阶段,我们可以用同样的训练集或者是另外的“转化”训练集。当复杂模型是由简单模型复合而成时,我们可以用各自的概率分布的代数或者几何平均数作为“软目标”。当“软目标的”熵值较高时,相对“硬目标”,它每次训练可以提供更多的信息和更小的梯度方差,因此小模型可以用更少的数据和更高的学习率进行训练。
蒸馏大致描述如下图:
cumbersome model表示复杂的大模型,distilled model表示经过knowledge distillation后学习得到的小模型,hard targets表示输入数据所对应的label ,例如[0,0,1,0]。soft targets表示输入数据通过大模型(cumbersome model)所得到的softmax层的输出,例如[0.01,0.02,0.98,0.17]。
Softmax公式:
qi 表示第 i 类的输出概率,zi、zj 表示 softmax 层的输入(即 logits),T 为温度系数,用来控制输出概率的soft程度。
论文方法的关键之处便是利用soft target来辅助hard target一起训练。
由于hard target 包含的信息量(信息熵)很低,而soft target包含的信息量大,拥有不同类之间关系的信息。比如同时分类驴和马的时候,尽管某张图片是马,但是soft target就不会像hard target 那样只有马的index处的值为1,其余为0,而是在驴的部分也会有概率。
这样做的好处是,这个图像可能更像驴,而不会去像汽车或者狗之类的,而这样的soft信息存在于概率中,以及label之间的高低相似性都存在于soft target中。但是如果soft targe是像这样的信息[0.98 0.01 0.01],就意义不大了,所以需要在softmax中增加温度参数T(这个设置在最终训练完之后的推理中是不需要的)
T 的意义可以用如下图 来理解,图中 红,绿,蓝 分别对用同一组z在T为(5,25,50)下的值,可以看出,T越大,值之间的差就越小(折线更平缓,即更加的 soft),但是相对的大小关系依然没变。
目标函数由以下两项的加权平均组成:
soft targets 和小模型的输出数据的交叉熵(保证小模型和大模型的结果尽可能一致)
hard targets 和小模型的输出数据的交叉熵(保证小模型的结果和实际类别标签尽可能一致)
算法示意图:
1、训练大模型:先用hard target,也就是正常的label训练大模型。
2、计算soft target:利用训练好的大模型来计算soft target。也就是大模型“软化后”再经过softmax的output。
3、训练小模型,在小模型的基础上再加一个额外的soft target的loss function,通过λ来调节两个loss functions的比重。
4、预测时,将训练好的小模型按常规方式(如上右图)使用。
利用大模型提取先验知识,将这种先验知识作为soft target让小模型学习
1、初步试验 Mnist数据集
训练一个有两层具有1200个单元的隐藏层的大型网络(使用dropout和weight-constraints作为正则)值得注意的一点是dropout可以看做是share weights 的ensemble models;
另外一个小一点的网络具有两层800个单元隐藏层没有正则
训练结果:第一个网络test error 67个,第二个是146个;再加入soft target并且T设置为20之后小型网络test error达到74个
2、在语音识别数据上的实验
3、在大规模数据集上的实验
注:文章来自于我的博客shawnluo.com,欢迎访问~!
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。