Appearance
引言
DeepSeek 推理模型的爆火,也成了茶余饭后的话题。前段时间和同事一起吃饭便聊到了 LLM 知识蒸馏技术,但我们都是非专业背景下对这个知识蒸馏技术其实挺陌生,通过学习并记录下来。
蒸馏和微调的区别
微调(Fine-tuning)
微调是指在预训练模型的基础上,通过少量数据进一步训练,以适应特定任务。具体步骤如下:
- 预训练模型:在大规模数据集上训练好的模型。
- 任务特定数据:针对特定任务的小规模数据集。
- 调整参数:在预训练模型的基础上,使用任务特定数据进行少量训练,调整模型参数。
微调的目的是利用预训练模型的通用特征,快速适应新任务。
蒸馏(Distillation)
蒸馏是一种模型压缩技术,用于将大型模型(教师模型)的知识转移到小型模型(学生模型)上。具体步骤如下:
- 教师模型:一个复杂且性能较好的模型。
- 学生模型:一个更小、更简单的模型。
- 知识转移:通过让学生模型模仿教师模型的输出(如软标签),实现知识传递。
蒸馏的目的是在保持较高性能的同时,减少模型的计算和存储需求。
主要区别
- 目的:蒸馏用于模型压缩,同时保留较高的性能;微调用于适应特定任务需求。
- 方法:蒸馏通过知识转移,在优化学生模型时是不需要引入新的参数的;微调一般是有监督学习,微调通过使用带标签的数据调整预训练模型,使其更好地适应特定任务或数据集(例如知识库QA)。
- 应用场景:蒸馏适用于资源受限的环境,微调适用于特定任务。
什么是蒸馏
核心思想
知识蒸馏的目标是将一个复杂的、性能较好的模型(教师模型)的知识转移到一个更小、更简单的模型(学生模型)中。这里的"知识"指的是教师模型对输入数据的输出分布(即软标签
,Soft Labels),而不仅仅是硬标签
(Hard Labels,如分类任务中的类别标签)。
什么是标签
1. 硬标签(Hard Labels)
- 硬标签是传统监督学习中的标准标签形式,通常为离散的、确定性的类别标签。
- 在分类任务中,硬标签一般以 One-Hot编码 表示。例如,在图片分类任务中,例如识别图片中的动物是猫、狗、鸟,真实识别为猫,则硬标签为
[1, 0, 0]
。
特点
- 确定性:明确指定唯一正确的类别。
- 信息量少:仅保留最终分类结果,不反映模型对不同类别的置信度或类别间的关系。
- 训练目标直接:直接优化模型输出与真实标签的匹配度(如
交叉熵损失
)。
2. 软标签(Soft Labels)
- 软标签是教师模型对输入数据预测的概率分布,通过 Softmax 函数生成。
- 例如,教师模型对识别图片中的动物是猫、狗、鸟概率分别为
[0.6, 0.3, 0.1]
(对每个动物的置信度)。
特点
- 概率性:反映教师模型对各个类别的置信度,包含更多信息(如类别间相似性)。
- 知识丰富性:包含教师模型的泛化能力和对"模糊样本"的决策逻辑(例如,某样本属于猫的概率为 0.6,狗的概率为 0.3,鸟的概率为 0.1)。
模型输出的概率是什么意思
例如上面描述的
猫、狗、鸟概率分别为
[0.6, 0.3, 0.1]
这个输出是一个概率分布,表示模型对输入数据的预测结果。具体含义取决于任务类型:
如果是分类任务:
假设任务是将输入分为 3 个类别(例如猫、狗、鸟)。
[0.6, 0.3, 0.1]
表示模型认为输入属于:- 类别 1 的概率是 60%(例如猫)。
- 类别 2 的概率是 30%(例如狗)。
- 类别 3 的概率是 10%(例如鸟)。
如果是文生文任务(如我们经常使用的 chatgpt
对话生成):
文生文任务是一个序列生成任务,模型需要逐词(或逐 token)生成输出。
在每一步生成时,模型会输出一个概率分布,表示下一个词(或 token)的可能性。
例如,假设词汇表有 3 个词("你好"、"再见"、"谢谢"),模型输出的
[0.756, 0.178, 0.066]
表示:- 下一个词是"你好"的概率是 75.6%。
- 下一个词是"再见"的概率是 17.8%。
- 下一个词是"谢谢"的概率是 6.6%。
知识蒸馏的大致步骤
1. 训练教师模型
- 教师模型是一个大型的、复杂的模型(例如 DeepSeek-R1)。
- 在训练集上训练教师模型,直到其达到较高的性能。
- 教师模型的输出不仅包括最终的预测类别(硬标签),还包括每个类别的概率分布(
软标签
)。
2. 生成软标签
- 对于每个输入数据,教师模型会输出一个概率分布(软标签)。例如,在识别图片是猫、狗、鸟的分类任务中,教师模型可能会输出:
python
[0.6, 0.3, 0.1]
这表示模型认为输入的图片 60% 的概率属于猫,30% 的概率属于狗,10% 的概率属于鸟。
- 软标签包含了更多的信息,比如类别之间的相对关系(例如:猫 和 狗 更相似)。
3. 温度参数(Temperature)
- 在生成软标签时,会引入一个**温度参数(Temperature, T)**来平滑概率分布。
- 温度参数的作用是调整教师模型输出的软标签的"软度"。较高的温度会使概率分布更加平滑("模糊"、"均匀"),类别之间的概率差距缩小,模型能够看到更多类别的"相似性",从而让学生模型更容易学习到类别之间的关系。较低的温度则正好相反。(ps: 较小的温度会使得模型的输出更加稳定)
- 温度参数的公式如下:
其中,zi 是教师模型输出的 logits,T 是温度参数。
当 T=1 时,Softmax 的输出是原始的概率分布。
当 T>1 时,Softmax 的输出会更加平滑,概率分布会更加均匀。
当 T<1 时,Softmax 的输出会更加尖锐,概率分布会更加集中。
4.计算损失函数
在知识蒸馏中,学生模型的训练通常结合了两部分损失:
- 蒸馏损失(Distillation Loss)
用于衡量学生模型输出与教师模型软标签之间的差异。常见的方法是使用KL 散度
(Kullback-Leibler Divergence): - 监督损失(Supervised Loss)
当训练数据包含真实标签时,可以同时使用交叉熵
损失来直接监督学生模型,使其输出接近真实标签(硬标签):
ps: 这里不做太多的介绍了,一方面篇幅有限,另方面本身是小白也在学习中...
5. 训练学生模型
学生模型是一个更小、更简单的模型(例如 DeepSeek-R1-32B),学生模型的目标是模仿教师模型的输出(
软标签
),而不是直接学习硬标签
。
- 通过反向
传播算法
,计算损失函数对学生模型参数的梯度。 - 使用梯度下降法更新学生模型的参数,使得损失函数逐渐减小。
- 经过多次迭代,学生模型的输出分布会越来越接近教师模型的输出分布。
反向传播算法(Backpropagation)是训练神经网络的核心算法之一。它的作用是计算损失函数对模型参数的梯度,然后利用这些梯度来更新模型的参数。
反向传播的步骤:
前向传播(Forward Pass):
- 输入数据通过学生模型,计算模型的输出。
- 例如,输入一张猫的图片,学生模型输出
[0.756, 0.178, 0.066]
(表示属于猫、狗、鸟的概率)。计算损失(Loss Calculation):
- 将学生模型的输出与目标(教师模型的软标签
[0.622, 0.245, 0.133]
)进行比较,计算损失值(如 KL 散度)。反向传播(Backward Pass):
- 从损失函数开始,沿着计算图(Computation Graph)反向传播,计算损失函数对每个参数的梯度。
- 梯度表示的是:如果稍微调整某个参数,损失函数会如何变化。
更新参数(Parameter Update):
- 使用梯度下降法(Gradient Descent)或其他优化算法,根据梯度更新模型的参数。
知识蒸馏的示例
假设我们有一个图片识别分类任务,类别数为3,分别是猫、狗、鸟的概率。
1.教师模型的输出(未经过温度调整):
python
logits = [5.0, 2.0, 1.0]
Logits 是模型输出的原始分数,表明三个类别的得分情况(此时不具备概率的属性)
Softmax 公式:
Softmax 是一种数学函数,用于多分类任务中。它的作用是将一组数值(模型输出的 logits)转换为概率分布。具体来说:
- 输入:一组数值(logits),例如
[5.0, 2.0, 1.0]
- 输出:一组概率值,所有值的和为 1,例如
[0.936, 0.047, 0.017]
经过 Softmax 计算将 logits 转化为概率分布后:
python
softmax(logits) = [0.936, 0.047, 0.017]
这表明教师模型输出概率分别为:
猫的概率为: 0.936
狗的概率为: 0.047
鸟的概率为: 0.017
2.引入温度参数(T=2):
引入温度后 Softmax 公式:
使用温度参数计算后的概率分布:
python
softmax(logits / T) = [0.736, 0.164, 0.100]
可以看到,温度参数使概率分布更加平滑。
3.学生模型的目标:
通过计算损失函数对学生模型参数的梯度,更新学生模型已有的参数(如权重和偏置),通过不断更新参数
让学生模型的输出应该尽量接近 [0.736, 0.164, 0.100]
这样学生模型就习得了教师模型的性能
总结
知识蒸馏技术作为一种有效的模型压缩方法,能够在保留较高性能的同时,大幅减少模型的计算和存储需求。这使得蒸馏在资源受限的环境中,尤其是在边缘设备和移动端应用中,具有重要的应用价值。
像 DeepSeek-R1 的满血版蒸馏出 32B、14B 的小模型一样,让个人电脑上运行接近大模型的性能成为了一种可能。
其他
- 蒸馏是一种模型压缩技术,常见的模型压缩技术还有 量化、剪枝等
- 这里未对学生模型中涉及如何反向传播计算梯度以及损失函数如何预测效果等,后面继续更新。