AIGC: Progressive Distillation 笔记
Google 出品,必属精品?
DDIM 知识蒸馏(Knowledge Distillation)
我们先从 DDIM 的知识蒸馏开始(2101.02388),在这个知识蒸馏的设定里面,我们有一个老师 (teacher) 和一个学生 (student),学生的目标是让自己的输出 尽量地接近老师的输出
,用数学公式表达,就是最小化:
另外,知识蒸馏有一个要求是,输出需要是确定的 (deterministic),所以这里采用的是 DDIM 的设定。
逐步蒸馏(Progressive Distillation)

逐步蒸馏(2202.00512)的设定很简单,相比于上面知识蒸馏的一步到位,逐步蒸馏采用的是分步进行蒸馏——首先有一个通过 N 步 DDIM 训练好的老师,然后有一个长得和老师一模一样的学生,想要以自己 1 步的输出去贴近老师 2 步的输出(意味着学生 DDIM 只需要 步),当这个学生学习结束以后,这个学生就成了新的老师,然后重复如上的过程。
论文对于扩散过程,用了一个更广泛的设定 . 我们通常所见到的 Variance Preserving 扩散过程,是其在
时的特例。
是所谓的 latent, 其实就是
加噪后的数据.
.
这里,我们在离散时间上进行训练和蒸馏,并且采用余弦方案 ,
代表了纯高斯噪声
(注意下标 t 的范围是从0到1)。
我们这里再定义一个信噪比 (Signal-to-Noise Ratio) . 在
的时候,很明显
,
, 故信噪比为 0.
Loss
针对 loss 函数我们有
代表了在 t 时间点所生成的图片。在公式做了如上的变形之后,我们可以把 loss 看成是在
空间里面的函数(预测图像和原图像的距离),而信噪比则控制了 loss 的权重 (weight). 这里我们把这个权重称作权重函数 (weighting function). 当然,我们还可以设计各种不同的权重函数。
这里,论文讨论了一个很有趣的现象——当 时(即扩散初期),因为
, 所以
任何一点小的波动都会被超级放大。在蒸馏的初期,因为我们的步数很多,早期的一些的错误会在后期被修复;但是越往下蒸馏,步数越少的时候,这种情况就要出问题了。在极端的情况下,如果我们这个逐步蒸馏,进行到只剩下一步了(意味着直接从纯高斯噪声一步生成图片),那么这个时候,整个 loss 也变成 0 了,学生就学不到任何东西了。
对此,论文里面有三种解决方案:
直接预测 x (绕过了
在分母上的问题)
预测
的同时,也预测
,然后用公式
生成图片。(两种渠道预测的
加权求和)
预测
, 然后

另外,论文里面还提出了两种可行的 loss 的方案:
Truncated SNR:
SNR+1:
DDIM Angular Parameterization
这里,我们对 DDIM 从另一个角度进行切入 ,所以
. 显然,由
, 我们可得
.
接下来,我们定义 的速度 (velocity) 为:
利用三角函数的那些定理(初高中知识哦),对上面的公式变形后,我们可以得到:
在这里,我们再定义一个预测速度 (predicted velocity):
根据公式 ,我们有:
所以这里解释了上一节的解决方案3的公式由来。
接下来我们要做的只是一些公式变形了,最终我们会得到:

学习目标
对于每一步的更新,其方法是可以有很多种的。
这里,论文里面使用的更新公式为:
, 对其求导的话就可以得到
.(这里论文假定了 score function
可以用
来近似;详细过程见论文附录)
有了更新公式以后,接下来的事情就简单了,根据公式先计算前一步的 , 根据
再计算前一步的
. 然后我们计算目标
,最小化上述的 loss. 大功告成。

完。
注:B站的公式编辑器频繁抽风,如果遇到一些 tex parse error 之类的错误时,尝试刷新一下页面。