一、论文相关信息
1.论文题目
Distilling the Knowledge in a Neural Network
2.论文时间
3.论文文献
https://arxiv.org/abs/1503.02531
二、论文背景及简介
三、论文内容总结
-
Introduction总结
- 生物在生命的不同阶段具有不同的需求,也因此具有不同的结构与功能。机器学习也一样,在训练与部署阶段具有不同的需求,因此我们需要在不同的阶段做不同的事情。本文所提出的蒸馏,便是将训练阶段的模型通过变形使其适应部署阶段的需求。
- Caruana等人提出,可以将一个大的冗杂的模型的知识转移到一个小的更适合部署的模型中去,这种方法就叫做 蒸馏(“distillation”)
- 我们很难去评价一个模型是否成功的转化到了另一个模型,特别是对于该模型的泛化能力来说,因为我们并没有这样的先验知识。而将其泛化能力进行转化,就是我们要讲到的重点,对对抗样本的防御方法。
- 将大模型的泛化能力转化为小模型的一个方法是利 用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)。
- 在Caruana等人做蒸馏时,出现了一些问题(见Intorduction第5段)。作者为了解决这个问题,提出,增加softmax的温度,直到其soft target比较合适位置。
- 我们对小模型进行训练时,使用的是 tranfer集 ,该数据集可以是原本的数据集,也可以是一些不带标签的数据组成的
-
Distillation方法
- 利用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)
- 使用大模型产生的logits作为目标,是蒸馏的一个特例
- 作者在MINST数据集和语音识别任务上都进行了实验
-
作者提出了一个
训练大型集成模型
的方法
- 集成模型包括一个通用模型和多个”specialist“模型,每一个”specialist“模型专注于一部分易于区分的类别,把剩下的类别作为垃圾类。采用了一个独特的推断方式进行预测。
- 使用soft targets进行训练时,可以提高模型的泛化能力,防止过拟合,而且只用一小部分训练集就可以训练的很好。
- 第一步,我们首先根据通用模型找到$ n$ 个最可能的类别,我们把它称之为类别集合$ k$ 。作者取n=1
- 第二步,我们找到所有的“specialist”模型,这些模型的训练集的混淆类别$ S^m$ 与$ k$ 相交不为空。把这些模型集合称之为$ A_k$ ,我们的目标就是找到一个概率分布$ q$ ,使得下面的式子最小:
四、论文主要内容
1、Introduction
一个概念块可能阻止了对这种非常有前途的方法的更多研究,那就是我们倾向于用学习到的参数值来识别训练模型中的知识,这使得我们很难看到如何改变模型的形式,但保持相同的知识。将知识从任何特定实例化中解放出来的更抽象的知识视图是,它是从输入向量到输出向量的学习映射。对于学习区分大量类的大模型,通常的训练目标是使正确答案的平均对数概率最大化,但学习的副作用是训练模型为所有错误答案分配概率,即使这些概率非常小,也有一些比其他的大很多。 不正确答案的相对概率告诉我们很多关于大模型如何趋向于泛化的信息 。例如,一辆宝马车的图片,被误认为是垃圾车的可能性很小,但这个错误的可能性仍然比误认为是胡萝卜的可能性大很多倍。
众所周知,目标函数应该尽可能反映用户的真实目标。比如, 当我们的目标函数目的是为了使模型更好的泛化到新数据时,我们就能够训练出一个泛化能力好的模型 ,但是,这需要我们对泛化要有足够的认识、信息。而我们目前并没有这些信息。 当我们将大模型蒸馏到小模型时,我们可以训练这个小模型使其具有与大模型一样的泛化能力 。而且,我们用训练大模型的方式去训练小模型,其小模型的泛化能力通常比用正常训练方式得到的小模型的泛化能力要好得多。
将大模型的泛化能力转化为小模型的一个方法是利 用大模型产生的类概率作为训练小模型的“soft targets”(就是将其label转换成大模型生成的类概率,进行训练)。 在这个转移阶段,我们可以使用同样的训练集或者使用一个单独 “transfer”数据集 。当 大模型是集成模型时,我们根据每个模型的贡献率计算类概率的算数均值或者几何均值作为”soft targets” 。当”soft targets”具有很高的熵时,当对于”hard targets”,其在训练的情况下,能够提供更多的信息量,且梯度变化也要小得多。因此 小模型往往比原始的大模型能够在更少的数据集上及逆行训练,并且可以使用更高的学习率。
对于MNIST这样的任务来说,大模型总能够以很高的置信度生成正确的答案。关于所学习函数的大部分信息都存在于软目标中非常小的概率比中。比如:一张2的图片,在一个版本中,可能给予类别3 10^-6的概率,给与类别7 10^-9的概率,但是在另一个训练版本中,可能是相反的。这是有价值的信息,定义了数据上丰富的相似结构,但是这对于交叉熵损失函数来说却有着很小的影响,因为对交叉上来说,概率太小了,接近0了。Carunan等人通过使用logits损失函数规避了这个问题。他们最小化 大模型生成的logits和小模型生成的logits的平方差。我们更普遍的解决方案,称为“ 蒸馏 ”,是 提高最终softmax的温度(下面会讲到),直到大模型产生一组合适的软目标。然后我们在训练小模型时使用相同的高温来匹配这些软目标 。我们稍后将展示,匹配大模型的logits实际上是蒸馏的一个特例。
“transfer”集,可能是由没有标签的数据组成的,也可能时使用原始训练集。我们发现使用原始训练集效果很好,特别是如果我们在目标函数中加入一个小项,鼓励小模型预测真实目标,以及匹配笨重模型提供的软目标。通常,小模型不能精确地匹配软目标,朝着正确答案的方向出错是有帮助的。
2、Distillation
神经网络通常使用softmax来生成类别概率,其将logits(神经网络输出的值)$ z_i$ 转换成概率$ q_i$ :
T表示的softmax的温度,通常设为1,T越大,则softmax就能够生成更软的概率(更软的意思我在AI小知识系列第一讲中讲到)
蒸馏网络就是我们要转移到的那个小模型
在最简单的蒸馏形式中,通过 使用transfer集和每个类别的软目标分布 (使用 带有高温的softmax的大模型 生成的,就是指上面提到的各个类别的soft target的分布), 对蒸馏网络进行训练从而将知识传递到蒸馏网络中 。训练蒸馏模型时使用相同的高温,但训练后,进行部署时,使用的温度为1。
当所有或部分transfer集都有正确的标签时,上面的方法还可以通过训练蒸馏模型来生成正确的标签来得到显著改进。而让所有或者部分tranfer集有正确的标签的一种方法,是使用正确的标签(原本的离散的标签)来调整soft targets,但是我们发现了一种更好的方式,就是简单的 使用两种不同目标函数的加权平均值 。 第一个目标函数是带有soft targets的交叉熵 ,该交叉熵使用跟蒸馏网络相同的高温来进行计算,其用于在大模型中生成软目标。 第二个目标函数是带有正确标签(hard targets)的交叉熵 ,该交叉熵使用的温度为1,其使用的是蒸馏网络中softmax的logits。作者发现,当第二个目标函数具有相对较低的权重时,能够得到最好的结果。因为,通过soft targets训练而得到的网络的梯度的大小大概在$ 1/T^2$ ,因此,我们要把hard targets 与 soft targets都乘以$ T^2$ 。当蒸馏网络的温度改变(调参)时,这保证了hard targets 与 soft targets 的相对分布不会发生改变。
2.1 匹配logits是蒸馏的一个特例
如果T比logits大,那么:
如果我们假设在transfer集上的所有logits的均值为0,即:$ \sum_j z_j = \sum_j v_j=0$ ,则:
因此,在高的T值以及logits的0均值的限制下,蒸馏等价于最小化$ \frac{1}{2}(z_i-v_i)^2$ 。而在低的温度下,蒸馏则不会特别的注重于匹配比平均值负得多logits,这是蒸馏的优势,因为被损失函数用作训练大模型的logits几乎没有限制条件,所以可能会存在很大的噪声。另一方面,非常负的logits,可能会传递由大模型产生的有用的信息。这些影响中哪一个占主导地位是一个经验问题。我们表明,当蒸馏模型太小而无法捕获大模型中的所有知识时,中间温度的效果最好,这强烈表明忽略较大的负对数是有帮助的。
3、在MNIST数据集上进行初步实验
当蒸馏网络,每层隐藏层有300+神经元时,将温度设置为8+,有着相当相似的结果。但是,当每层隐藏层的神经元降到每层30个的时候,将温度设置为2.5~4,得到的效果比4+或者2.5-的温度要好。
之后作者尝试将数字3从transfer集中删除,对蒸馏网络来说,3是个神秘的数字,因为训练集里没有3。这样试验后,蒸馏网络将206张图片分类错误(133张为数字3,测试集上总共1010张3)。大部分分类错误的原因是,给类别3 分的权重太低了,如果将这个权重乘以3.5,那么就只有109张错误的,其中,只有14张是数字3。所以,当有着正确的权重时,蒸馏网络能够得到98.6%的对3的正确率。如果我们继续增大权重,增加到7倍或者8倍,那么蒸馏网络的错误率就上升到了47.3%。
4、在语音识别上进行试验
作者在语音识别任务上进行了实验,使用的ASR模型(Automatic Speech Recognition),发现使用蒸馏策略得到的模型要比直接训练得到的相同大小的模型的效果好得多。
5、在大数据集上进行训练出集成模型
在这一节,我们给除了这样的例子,我们展示了,如何学习到这样一个集成模型,在集成模型中,每个单独的模型都关注一个不同的子集,这样就减少了计算量,但这样也带来了问题,就是很容易过拟合。而这个过拟合可以通过使用soft targets来解决。
5.1 JFT数据集
很明显,用几个月的时间来训练一个模型并不是一个好的选择,所以我们需要找到一个更快的训练baseline的方法。
5.2 Specialist Models
为了减少过拟合,并且让这些“specialist”的模型能够共享一些低阶特征,每 一个“specialist”模型都会用通用模型的权重来进行初始化 ,然后在他们特别的训练集上进行fine-tuning。这些特别的训练集是这样得到的,他们一半来自于特殊的子集,一半来自于训练集的补集(lable为垃圾类别)。在训练结束后,我们可以通过 将垃圾类别的logits乘以抽样比的对数来修正 。
5.3 将类别分配给模型
我们将聚类算法应用于我们的广义模型预测的协方差矩阵,一组经常一起预测的类$ S^m$ 将被用作我们的一个“specialist”模型m的目标,我们对协方差矩阵的列应用了在线版本的K-均值算法,并获得了合理的簇。
5.4 集成模型的推断
上面所讲到的集成模型的预测过程如下,给定一个输入图片$ x$ :
其中$ KL$ 表示$ KL$ 散度,$ p^g$ 表示通用模型的类别概率,$ p^m$ 表示第$ m$ 个“specialist”模型得到的类别概率(包括一个垃圾类别)
我们会参数化$ q=softmax(z)(T=1)$ ,然后使用梯度下降来优化logits$ z$ 。
5.5 结果
在这次实验过程中,使用了61个”specialist“模型,每一个模型包括300个类别(包括垃圾类别)。作者发现,当一个类别被更多的”specialist“模型覆盖时,它预测的准确率就提升的越高。
6、Soft Targets as Regularizers
6.1 利用soft targets来防止过拟合