跳转至

知识蒸馏1

概要

一个简单的提升机器学习算法性能的方法是在相同的数据上训练多个不同的模型,然后平均它们的预测结果. 不幸的是,使用集成模型进行预测较为繁琐,并且可能计算成本过高,无法部署到大量用户,特别是当每个模型都是大型神经网络时. Caruana及其同事的研究2表明,可以将整个集成模型中的知识压缩到单个模型中,这种方法更容易部署. 研究人员在此基础上,进一步探索了不同的压缩方法. 他们在MNIST数据集上取得了令人惊讶的成果,并展示了可以通过知识蒸馏将一组模型的知识融合到单个模型中,从而显著提升某个常用商业系统的声学模型. 除此之外,他们还提出了一种新型的集成方法,该方法结合了一个或多个全功能模型(full model)和多个专用模型(specialist model). 首先使用全功能模型进行初步分类,当全功能模型对某些分类产生混淆时,则用专用模型进一步处理和区分. 与专家混合(mixture of experts)集成方法不同,他们提出的这种集成方法能够并行训练,大幅缩短整体训练时间.

背景

昆虫的不同形态的启发

许多昆虫的幼虫形态被优化以从环境中提取能量和营养, 但是它们有着完全不同的成虫形态以满足非常不同的旅行和繁殖需求. 然而, 在现在的大型机器学习中, 我们通常在训练和部署阶段使用非常相似的模型, 尽管这两个阶段的实际需求差异很大: 对于语音和物体识别等任务, 训练阶段必须从海量, 高度冗余的数据集中提取结构, 但是无需具有实时性, 并且可以使用大量的计算资源. 然而, 将它部署到大量用户设备的时候, 对延迟和计算资源的要求就严格许多.

对昆虫形态的类比给我们的启发是: 如果能够更容易地从数据中提取结构, 我们应该愿意训练非常庞大的模型. 这种庞大的模型可以是由单独训练的多个模型组成的集成, 或者是使用非常强的正则化方法, 如Dropout训练的单一大型模型. 一旦庞大模型训练完成, 我们就可以使用一种不同的训练方法, 称为"蒸馏", 将庞大模型中的知识转移到另一个更加适合部署的小模型中. 这种策略的一个版本已经由Caruana及其同事率先实现2, 在他们的论文中, 有效地证明了大型模型集成所获得的知识可以转移到单个小模型中.

知识是一种输入到输出的映射

一个可能阻碍进一步研究这一非常具有前景的方法的概念性障碍是, 当我们把一个训练好的模型的知识看作是其具体参数值的集合时, 就很难想象如何在改变模型结构(例如, 简化模型)的时候仍然保留这些知识, 这种观念限制了我们探索如何高效地将知识从一个模型转移到另一个模型的方法. 然而, 在更加抽象的层面上, 使其摆脱任何特定的实例, 知识可以被视为一种从输入向量到输出向量的学习映射.

错误类别概率差异很重要

对于那些需要区分大量类别的庞大模型, 通常的目标是最大化正确答案的平均对数概率, 然而, 在学习过程中, 训练好的模型不仅会给正确答案分配高概率, 还会为所有错误答案分配概率, 即使这些错误的概率很小, 某些错误的概率仍然可能比其他错误答案大得多. 这种错误答案的相对概率蕴含了许多关于模型泛化能力的信息, 例如, 对于一张BMW的照片, 它被误分类为垃圾卡车的概率虽然很小, 但是仍然比分类为胡罗卜的概率要高很多, 这表明模型能有有效地区分不同地类别, 并在区分类似类型(例如, 车辆中的不同类型)的时候, 仍然具有一定的辨别能力. 这说明, 模型对不同错误类别之间概率差异的度量承载了重要的模型泛化信息, 这种更加丰富的概率分布信息, 恰恰是进行知识提炼和知识迁移的时候非常具有价值的部分.

使用蒸馏间接提高泛化能力

在训练机器学习模型的时候, 目标函数(也称为损失函数)用于衡量模型在训练数据上的表现, 通常希望最小化损失函数. 但是用户真正关心的是模型在未见过的新数据(测试数据)上的表现, 即模型的泛化能力. 虽然理想情况下目标函数应该直接反映用户的真实目标(即泛化能力), 但是实际上很难设计出能够完全反映这一点的目标函数. 由于直接优化目标函数难以完美反映泛化能力, 知识蒸馏提供了一种间接提升模型泛化能力的方法. 具体来说, 知识蒸馏通过利用一个性能更强, 泛化能力更好的大型模型(教师模型)的知识, 来训练一个较小的模型(学生模型). 教师模型泛化能力好是因为它是一个多个大型不同模型集合的平均值. 通过继承泛化能力得到的学生模型和在与教师模型相同数据集上按照普通方式训练的小模型表现更优.

使用软目标提高泛化能力

教师模型在进行分类的时候, 会输出每个类别的概率分布, 而不是仅仅给出一个确定的类别标签, 这些概率分布被视为"软目标"(Soft Target), 也就是说, 这个软目标里面包含了错误类别概率. 在传统的训练中, 目标通常是一个明确的类别标签(如"一张图片属于猫"), 也叫做"硬目标". 为了将教师模型的泛化能力转移到学生模型, 可以使用教师模型生成的软目标作为学生模型的训练目标, 这种方法不仅仅让小型模型学习了正确的类别, 还让其了解了类别之间的相对关系和模型对各类别的不确定性. 这个过程就是使用蒸馏间接提高了泛化能力.

可以在与训练教师模型相同的数据上进行知识蒸馏, 也可以使用一个独立的"转移"数据集. 特别的, 如果教师模型实际上是由多个简单模型组成的集成模型的时候, 每个简单模型会对输入数据产生自己的预测分布, 可以将所有简单模型的预测分布取算术平均, 得到最终的软目标, 也可以采用几何平均的方法来融合各个模型的预测分布.

软目标通常具有更高的熵, 这意味着对于每个训练样本来说, 它们包含着更多的不确定性信息(类别之间), 而不是像硬目标那样仅仅是一个确定的类别标签, 所以可以在更少的数据上进行训练. 并且, 训练样本之间的梯度变化的方差显著减小, 这意味着参数更新的步伐非常均匀, 所以可以使用一个相对较大的学习率(方差较大的梯度变化需要用较小的学习率保证训练的稳定性).

使用温度解决置信度过高问题

对于像MNIST这样的任务, 教师模型几乎总是极高的置信度给出正确的答案, 学习到的函数的大部分信息都存在于软目标的非常小的比例中, 例如, 一个正确标签为2的样本有\(10^{-6}\)的可能性会产生3的输出, 有\(10^{-9}\)的概率会产生7的输出, 这种信息很具有价值, 正如错误类别概率差异很重要小节描述的那样, 它定义了数据上丰富的相似性结构(例如, 它指出哪些像2, 哪些像3, 哪些像7), 但是在转移阶段对学生模型的交叉熵成本函数的影响却非常小, 因为概率接近于0.

为了解决这个问题, Caruana及其同事选择使用logits(即最终softmax层的输入, ⚠️注意, logits可以是负值)而不是softmax产生的概率作为学生模型的学习目标, 并通过最小化教师模型和学生模型产生的logits之间的均方误差来进行训练. 作者则提出了一种更加通用的解决方案, 称为"蒸馏", 他们在softmax函数中引入一个温度参数, 通过提高温度(大于1), softmax输出的概率分布会变得更加"柔和"或者平滑, 这意味这类别之间的概率差异会减小, 提供更多关于各类别相对关系的信息, 也就是通过温度来降低置信度. 教师模型和学生模型需要使用相同的温度. 作者指出, Caruana等人直接使用logits作为目标其实是蒸馏的一种特殊情况, 当温度参数趋近于1的时候, 蒸馏方法退化为直接匹配logits, 因此, 蒸馏方法更加通用, 可以通过调整温度来控制学习目标.

硬目标辅助训练

用于训练学生模型的迁移集可以完全由未标记的数据构成, 也就是完全使用软目标, 或者直接使用原始数据集. 作者发现使用原始训练集的效果良好, 特别的, 他们往目标函数里面加了一个小项, 鼓励学生模型在匹配由教师模型提供的软目标的同时也鼓励学生模型预测真实目标. 这是因为, 学生模型无法完全匹配软目标, 它会出现一定程度的偏差, 如果这个偏差正好能让它更倾向于正确的类别, 那么从整体上看, 这有利于学生模型的最终效果.

方法

神经网络通常使用softmax来产生类别概率, 该层将为每个类别计算的logit值\(z_i\)转换为概率\(q_i\).

\[ q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)} \]

其中, \(T\)为温度, \(T=1\)(Caruana等人), 使用更高的\(T\)值会在类别上产生更加柔和的概率分布. 如果将\(T\)设大, 则各个类别之间的概率分值差距会缩小, 也就是强化那些非最大类别的存在感. 反之, 则会加大类别间概率的两级分化. 当\(T\)趋近于\(0\)的时候, softmax输出将收敛为一个one-hot向量; 当\(T\)趋近于\(\infty\)的时候, \(e^{z_i}/T\rightarrow 1+z_i/T\)3.

在最简单的蒸馏方法中, 教师模型使用高温度的softmax为迁移集中的每个样本生成软目标. 随后, 学生模型采用相同的温度, 通过最小化其输出和这些软目标之间的均方误差来进行训练. 但是, 当学生模型训练结束即将部署的时候, 它的温度会被设置为1.

硬目标辅助训练

当迁移集中的所有或者部分样本标签已知的时候, 同时训练学生模型去输出正确的标签可以显著提高该方法的效果, 即硬目标辅助训练. 一种实现方式是利用正确标签来修改软目标, 但是作者发现更好的方式是对两个不同的目标函数使用加权平均, 第一个目标函数是和软目标之间的交叉熵, 其中在计算交叉熵的时候, 学生模型的softmax使用和教师模型生成软目标时相同的高温度; 第二个目标函数是和正确标签之间的交叉熵, 不过这个时候学生模型的softmax的温度被设置为1. 他们发现, 在第二个目标函数上使用更低的权重通常能够取得更好的结果.

由于软目标产生的梯度大小随着\(\frac{1}{T^2}\)缩放, 所以当同时使用硬目标和软目标的时候, 必须将软目标的梯度乘以\(T^2\). 这样, 在实验中改变温度的时候, 硬目标和软目标的对梯度的相对贡献就能基本保持不变.

梯度表示

在迁移集中, 每个样本都会对学生模型的每一个logit, \(z_i\), 由此产生软目标概率\(q_i\), 并贡献一个交叉熵梯度, \(dC/dz_i\), 教师模型有对应的logit, \(v_i\), 并由此产生软目标概率\(p_i\), 且在温度\(T\)下进行迁移训练, 那么, 该梯度可以被表示为:

\[\frac{\partial C}{\partial z_i} = \frac{1}{T}(q_i - p_i) = \frac{1}{T}\left(\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}\right)\]

高温的近似

如果温度较高的时候, 根据刚才的公式, \(T\rightarrow \infty,\ e^{z_i}/T\rightarrow 1+z_i/T\). 可以得到:

\[\frac{\partial C}{\partial z_i} \approx \frac{1}{T} \left( \frac{1 + z_i / T}{N + \sum_j z_j / T} - \frac{1 + v_i / T}{N + \sum_j v_j / T} \right)\]

假设我们已经对所有的logits进行零-均值化处理, 即\(\sum_j z_j=\sum_j v_j=0\), 上述公式能继续简化:

\[ \frac{\partial C}{\partial z_i} \approx \frac{1}{N T^2} \left(z_i - v_i\right) \]
为什么零-均值化不会影响结果

假设我们对所有的logits减去它们的均值: \(z_i'=z_i-\mu\), \(\mu=\frac{\sum z_i}{n}\), 其中\(n\)是类别的数量. \(\exp(z_i')=\exp(z_i-\mu)=\exp(z_i)\exp(-\mu)\). 零-均值化之后的softmax可以被表示为\(\frac{\exp(z_i')}{\sum_j \exp(z_j')}\), 由于\(\exp(-\mu)\)会在分子和分母中相互抵消, 因此前后softmax的输出是一样的.

温度的选择

由于\(N\), \(T\)是不变的, 所以在高温度, 所有logits进行零-均值化处理的情况下, 蒸馏相当于就是最小化\(\frac{1}{2}(z_i-v_i)^2\), 也就是说, 学生模型直接去拟合教师模型的logits, 不用拟合softmax之后的概率分布了. 而在温度较低的时候, 分布会变得非常"尖锐", 最大类别的存在感非常足, 极负的logits对最终的损失贡献很小, 因此学生模型会"忽略"这些极负的logits, 而更倾向于对较大的logits做出精确拟合, 这样的作法也有好处, 因为这些极负的logits对应的可能只是噪声, 不过, 这些极负的logits有时也可能蕴含教师模型"学到"的有价值信息, 如果完全忽略, 可能造成潜在的信息丢失. 最终是否要忽略这些极负的logits, 要看经验结果. 当学生模型参数量较小, 无法完整吸收教师模型所有知识的时候, 适当使用"中间温度"的效果通常最好, 也就是说, 既不过分忽略那些负logits, 又不让学生模型在噪声上过度拟合, 往往能取得很好的平衡.


  1. Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the knowledge in a neural network (No. arXiv:1503.02531). arXiv. https://doi.org/10.48550/arXiv.1503.02531 

  2. Buciluǎ, C., Caruana, R., & Niculescu-Mizil, A. (2006). Model compression. Proceedings of the 12th ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 535–541. https://doi.org/10.1145/1150402.1150464 

  3. 知识蒸馏—康行天下—博客园. (不详). 从 https://www.cnblogs.com/makefile/p/kd.html 

评论