深度学习进阶(六)归纳偏置与蒸馏
<p>在上一篇,我们已经完成了 Vision Transformer的完整逻辑:把图像切成 patch 当作 token,送入 Transformer Encoder 做全局建模。</p><p>但我们也提到了, ViT 存在一个绕不开的痛点:</p>
<blockquote>
<p><strong>没有足够大的数据规模,ViT 往往很难训练得好。</strong></p>
</blockquote>
<p>而用范式角度来说,这是因为<strong>ViT 本质上是一种“弱先验、强数据驱动”的建模方式。</strong></p>
<p>再展开一些,对于这个问题:</p>
<blockquote>
<p><strong>为什么 ViT 需要大量数据才能表现良好?而 CNN 在小数据下却依然有效?</strong></p>
</blockquote>
<p>我们在高光谱成像的内容中已经展开过先验相关的内容,知道卷积网络把大量视觉先验写进了结构里,而 ViT 更偏“数据驱动”,空间关系主要靠统计学习出来,因此对数据规模和训练配方高度敏感。</p>
<p>但单靠这些笼统的概念就直接进入 ViT 的下一步改进还是略显单薄。<br>
所以本篇主要介绍 DL 中的两个重要概念:</p>
<ol>
<li><strong>归纳偏置(Inductive Bias)</strong></li>
<li><strong>蒸馏(Distillation)</strong></li>
</ol>
<p>在了解这些内容后就可以较通畅地进入 ViT 的其中一种改进逻辑。</p>
<h1 id="1-先验信息和归纳偏置">1. 先验信息和归纳偏置</h1>
<p>首先,我们需要明白两个高度相关但并不等价的概念:<strong>先验信息(Prior Knowledge) 与 归纳偏置(Inductive Bias)</strong>。</p>
<h2 id="11-什么是先验信息">1.1 什么是先验信息?</h2>
<p>再简单复述一遍,用一句话定义:</p>
<blockquote>
<p><strong>先验信息是我们在学习之前就已经知道的“关于世界的规律”。</strong></p>
</blockquote>
<p>它并不来源于数据,而是<strong>来源于经验或认知</strong>。<br>
例如在视觉任务中,我们天然知道:图像是连续的、相邻像素之间更相关、物体具有结构等等基本认知,简单展开一些如下:</p>
<ol>
<li><strong>结构先验:</strong><br>
人脸中眼睛在鼻子上方,嘴巴在鼻子下方,整体呈现稳定的空间排列关系。</li>
<li><strong>局部相关性先验:</strong><br>
一张图像中,相邻像素通常属于同一物体,例如一片天空区域,其颜色和纹理在局部范围内是平滑且相似的,而不会突然剧烈变化。</li>
<li><strong>连续性先验:</strong><br>
图像中的边缘和轮廓通常是连续的,比如一条道路或物体边界,不会在相邻位置随机中断或跳跃。</li>
</ol>
<p>这些都属于<strong>对真实世界的描述</strong>,也就是先验信息。</p>
<h2 id="12-什么是归纳偏置">1.2 什么是归纳偏置?</h2>
<p>相比之下,归纳偏置是一个更“模型视角”的概念:</p>
<blockquote>
<p><strong>归纳偏置是模型在学习过程中“更倾向于选择某一类解”的机制,它通常来源于先验信息。</strong></p>
</blockquote>
<p>它并不是具体的知识,而是<strong>我们根据先验在模型中进行的设计</strong>,从而使得:模型更容易学到什么、模型更不容易学到什么。<br>
总结来说就是:<strong>归纳偏置决定了模型的“学习方向”。</strong></p>
<p>到这里,我们就可以较完善地解释最初的问题:<strong>为什么 ViT 需要大量数据才能表现良好?而 CNN 在小数据下却依然有效?</strong></p>
<h2 id="13-cnn-和-vit">1.3 CNN 和 ViT</h2>
<p>先用刚刚的概念来总结下: <strong>CNN 和 ViT,本质上是在“是否引入归纳偏置”这个问题上的两种不同选择。</strong><br>
这里要先强调一下,由归纳偏置引起的数据依赖性只是相对而言,基于 DL 的方法本身都是数据驱动的。</p>
<p>先说 CNN ,我们对 CNN 的建模逻辑本身就是在做一件事:<strong>把先验信息写进模型结构,而这就是归纳偏置的体现。</strong></p>
<p>具体展开对照:</p>
<ol>
<li>因为局部性先验,我们设计了<strong>卷积核</strong>,来限制模型只能关注局部区域。</li>
<li>我们使用<strong>网络层级结构</strong>,从局部逐步组合成全局,其实也是结构性先验的体现。</li>
</ol>
<p>你会发现:<strong>CNN 在训练开始之前,就已经被“规定”了如何理解图像。</strong></p>
<p>再从数学角度展开:CNN 的学习并不是在一个完全自由的空间中进行,而是<strong>在一个被强约束的函数空间中寻找解。</strong></p>
<p>这就是强归纳偏置,直接结果就是即使数据不多,模型也能较快学到合理结构更容易收敛,不容易学偏。</p>
<p>但代价就是<strong>模型的表达能力被结构限制</strong>,这其实限制了模型的<strong>上限</strong>,因为模型的学习逻辑不一定非要按照我们人类理解的逻辑来进行。</p>
<p><img src="https://img2024.cnblogs.com/blog/3708248/202604/3708248-20260410165106271-498293349.png" alt="image.png" loading="lazy"><br>
(有段时间没用 GPT 生成中文配图了,大概四个月前生成的图中的中文还是有很多错乱的,真的是在不停地进化。)</p>
<p>而在 ViT 中,我们几乎做了相反的选择:<strong>尽量不把先验写进结构,而是交给数据去学习。</strong><br>
体现就是不使用卷积,所有 patch token 通过注意力直接全局交互,这就是在空间层次上更弱的归纳配置。</p>
<p>也就是说: <strong>模型一开始并不知道“什么是局部结构”,也不知道“什么是空间关系”。</strong><br>
因此,ViT 的学习过程是<strong>在一个几乎不受约束的巨大函数空间中搜索解。</strong></p>
<p>这带来的是表达能力更强,但训练难度显著增加对数据规模和训练策略高度敏感。</p>
<p>久违地打个比方:<br>
CNN 像是“带着地图找路”,有<strong>大致</strong>的方向,可以更稳定地训练寻路能力。<br>
而 ViT 就像是“在未知环境中反复试错”,大量训练后,就拥有了更强的寻路能力。<br>
<strong>而“地图”就是归纳偏置。</strong></p>
<h2 id="14-小结">1.4 小结</h2>
<p>把这部分的内容总结为规律就是:</p>
<blockquote>
<p><strong>归纳偏置越弱,模型对数据的依赖就越强。</strong></p>
</blockquote>
<p>到这里,自然就有了下一个问题:</p>
<blockquote>
<p><strong>如果我们不想改变 ViT 的结构,又希望它在小数据上表现更好,该怎么办?</strong></p>
</blockquote>
<p>答案就是下一部分的内容:<strong>蒸馏</strong>。</p>
<h1 id="2-蒸馏">2. 蒸馏</h1>
<p>如果用一句话来概括:<strong>蒸馏就是让一个“小模型”去模仿一个“大模型”的输出,从而学到更好的决策能力。</strong></p>
<p>单看这句话,你可能会联想到我们之前介绍过的迁移学习。它们看起来都是在“借助一个强模型的能力”,但本质逻辑是不同的:</p>
<ul>
<li><strong>迁移学习</strong>:把“已经学到的参数”拿过来用。</li>
<li><strong>蒸馏</strong>:不直接用参数,而是<strong>模仿模型的行为(输出分布)</strong>。</li>
</ul>
<p>在详细展开前,我们先统一两个核心角色:</p>
<ol>
<li><strong>Teacher(教师模型)</strong>:通常是一个性能较强、已经训练好的模型。</li>
<li><strong>Student(学生模型)</strong>:我们真正要训练的目标模型。</li>
</ol>
<p>这种命名也是相关领域文献内的主流称呼,下面就来展开蒸馏的思路。<br>
<strong>要提前说明的是,蒸馏技术的逻辑是相通的,但其本身存在多种形式,在不同类型的任务中的实现方式也不同,这里我们使用最基础的分类任务来演示:</strong></p>
<h2 id="21-准备-teacher">2.1 准备 Teacher</h2>
<p>要进行蒸馏,<strong>首先要完成对 Teacher 的准备</strong>。</p>
<p>Teacher 通常是一个<strong>已经训练好的强模型</strong>,最常见情况就是<strong>直接用预训练好的 Teacher</strong>。</p>
<p>当然,如果在某些研究或特殊任务中,没有现成的强模型可用,那这时的流程就是先训练一个性能尽可能好的 Teacher,再用它去蒸馏 Student 。<br>
你可能会觉得这种方式有些没必要,<strong>但这是因为我们通常不是在追求最强模型,而是在追求“足够强 + 足够便宜”的模型。</strong></p>
<p>总之,这里的一个重要原则就是:</p>
<blockquote>
<p><strong>Teacher 不一定要“大”,但一定要“比 Student 更可靠”。</strong></p>
</blockquote>
<p>不然 Student 学到的是错误知识,蒸馏反而拖累性能。</p>
<h2 id="22-软标签-soft-label">2.2 软标签 soft label</h2>
<p>准备好 Teacher 后,第二步就是在 Teacher 上运行数据,获取<strong>软标签</strong>。</p>
<p>不难理解,它的操作是这样的:<br>
对于同一张输入图像, Teacher 不仅给出最终类别。还会输出一个完整的概率分布:</p>
<p></p><div class="math display">\
\]</div><p></p><p>这个分布包含了<strong>类别之间的相似关系信息</strong>。<br>
比如:</p>
<p></p><div class="math display">\[\text{cat} = 0.7,\quad \text{dog} = 0.2,\quad \text{car} = 0.1
\]</div><p></p><p>在这你会发现,这里<strong>其实就是直接获取经过输出层 softmax 得到的概率分布</strong>。</p>
<p>还没完,在实际蒸馏中,通常<strong>会引入一个温度参数 <span class="math inline">\(T\)</span> 来“软化”软标签和 Students 输出后再进行相关计算</strong>:</p>
<p></p><div class="math display">\[p_i = \frac{e^{z_i / T}}{\sum_j e^{z_j / T}}
\]</div><p></p><p>这样,当设计 <span class="math inline">\(T > 1\)</span> 时,概率分布就会变得更平滑、小概率类别被“放大”,让类比间的差距更明显。<br>
实际上,<strong>它是在真实运行中经常调试的超参数</strong>。<br>
<img src="https://img2024.cnblogs.com/blog/3708248/202604/3708248-20260410165106453-916740833.png" alt="image.png" loading="lazy"></p>
<p>总之,我们通过 Teacher 获取了对数据集的软标签,从而<strong>得到了哪些类别“接近正确”、模型的“犹豫程度”、决策边界的大致结构这类更细节的信息。</strong></p>
<h2 id="23-student-模仿-teacher">2.3 Student 模仿 Teacher</h2>
<p>这里就是蒸馏的核心逻辑,其实用简单的话来解释就是:</p>
<blockquote>
<p><strong>设计 Student 的损失函数,让 Student 的训练拟合目标除了学习真正标签,还学习 Teacher 输入的软标签,即拟合 Teacher 输出的概率分布。</strong></p>
</blockquote>
<p>下面的数学公式较为繁琐,先理解了这部分的逻辑后,就问题不大了:</p>
<p>首先要介绍的是 <strong>KL 散度</strong>(KL divergence,全称 Kullback–Leibler divergence)。<br>
对于离散分布,KL散度定义为:</p>
<p></p><div class="math display">\[D_{KL}(P \parallel Q) = \sum_{i} P(i)\log \frac{P(i)}{Q(i)}
\]</div><p></p><p>而在实际实现中为了简化运算,其等价形式是:</p>
<p></p><div class="math display">\[\mathcal{L}_{KD} = - \sum p_t \log p_s
\]</div><p></p><p>整体在蒸馏语境里通常写成:</p>
<p></p><div class="math display">\[D_{KL}(p_t \parallel p_s)
\]</div><p></p><p>其中:</p>
<ul>
<li><span class="math inline">\(p_t\)</span>:Teacher 的分布(真实参考)。</li>
<li><span class="math inline">\(p_s\)</span>:Student 的分布(要学习的对象)。</li>
</ul>
<p>语义上说,KL在这里的作用就是衡量 <strong>Student 的预测分布和 Teacher 的分布之间的差异有多大。</strong><br>
看个实例:</p>
<table>
<thead>
<tr>
<th>类别</th>
<th><span class="math inline">\(p_t\)</span> (Teacher)</th>
<th><span class="math inline">\(p_s\)</span> (Student)</th>
<th>比值 <span class="math inline">\(\frac{p_t}{p_s}\)</span></th>
<th><span class="math inline">\(\log(\frac{p_t}{p_s})\)</span></th>
<th><span class="math inline">\(p_t \cdot \log(\frac{p_t}{p_s})\)</span></th>
</tr>
</thead>
<tbody>
<tr>
<td>Cat</td>
<td>0.7</td>
<td>0.5</td>
<td>1.4</td>
<td>0.336</td>
<td>0.235</td>
</tr>
<tr>
<td>Dog</td>
<td>0.2</td>
<td>0.4</td>
<td>0.5</td>
<td>-0.693</td>
<td>-0.139</td>
</tr>
<tr>
<td>Car</td>
<td>0.1</td>
<td>0.1</td>
<td>1.0</td>
<td>0.000</td>
<td>0.000</td>
</tr>
</tbody>
</table>
<p>最终:</p>
<p></p><div class="math display">\[D_{KL}(p_t \parallel p_s) = 0.235 - 0.139 + 0 = 0.096
\]</div><p></p><p>这个结果的组成逻辑是这样的:</p>
<ol>
<li>Cat:Student 给的概率 <strong>偏低(0.5 < 0.7), 有误差</strong>。</li>
<li>Dog:Student 给的概率 <strong>偏高(0.4 > 0.2),有误差</strong>。</li>
<li>Car:没问题。</li>
</ol>
<p>显然,结果越小,就代表两种分布越接近。</p>
<p>而普通分类任务中,交叉熵损失函数如下:</p>
<p></p><div class="math display">\[\mathcal{L}_{CE} = -\sum_{i=1}^{C} y_i \log p_i
\]</div><p></p><p>最终,Student 的损失函数就是二者的组合:</p>
<p></p><div class="math display">\[\mathcal{L} = \alpha \mathcal{L}_{CE} + (1 - \alpha)\mathcal{L}_{KD}
\]</div><p></p><p>其中,<span class="math inline">\(\alpha\)</span> 为调节权重的超参数,两项损失分布代表:</p>
<ol>
<li><span class="math inline">\(\mathcal{L}_{CE}\)</span>:告诉 Student “标准答案是什么”。</li>
<li><span class="math inline">\(\mathcal{L}_{KD}\)</span>:告诉 Student “一个更强模型是怎么思考的”。</li>
</ol>
<p>在原始蒸馏方法中,为了补偿前面的 <span class="math inline">\(T\)</span> 对梯度的缩放,损失还会引入 <span class="math inline">\(T^2\)</span> 进行修正,但在现代实践中影响较小,常常省略,了解即可。</p>
<p>如此开始训练传播,最终便可得到蒸馏模型 Student 。</p>
<h2 id="24-小结">2.4 小结</h2>
<p>其实你会发现蒸馏是一种取巧的逻辑:<strong>对一个强大的模型,我直接去学习你的答案分布。</strong><br>
但实际上,蒸馏确实有其理论支持和实际价值,而这里展示的也只是一种较原始的逻辑,之后我们再详细展开。</p>
<p>回到 ViT,我们已经知道了它的问题在于“搜索空间太自由”。那么蒸馏在这里的作用就是:</p>
<blockquote>
<p><strong>人为引入一个“软约束”,缩小搜索空间,使优化更稳定,从而减少数据依赖。</strong></p>
</blockquote>
<p>这种逻辑实际上仍然是在利用 Teacher 的归纳配置。<br>
同样的,当 Teacher 本身存在偏差时,这种约束也会间接限制性能上限,因此,对 <span class="math inline">\(\alpha\)</span> 的调试也至关重要。</p>
<p>了解完两个概念后,就可以继续 ViT 的下一步改进了。</p><br><br>
来源:https://www.cnblogs.com/Goblinscholar/p/19847836
頁:
[1]