欢迎光临散文网 会员登陆 & 注册

AIGC: PNDM (Pseudo Numerical Method for Diffusion Models) 笔记

2023-08-04 09:58 作者:刹那-Ksana-  | 我要投稿

又是一种新的采样方法,又是一个新的公式地狱?

Runge-Kutta 方法

Runge-Kutta 方法是一种在欧拉方法的基础上,比欧拉方法精度更高的解常微分方程(Ordinary Differential Equation)%5Cfrac%7Bdx%7D%7Bdt%7D%3Df(x%2Ct)数值解法(Numerical Method)。

其一般的形式为

x_%7Bt%2B1%7D%3Dx_t%2BhAk_1%2BhBk_2%2BhCk_3%2B...

k_1%2Ck_2%2C... 有其固定的形式,A, B, C,... 则是待定系数。这里有多少个 k 就代表了多少阶 Runge-Kutta. 有关 Runge-Kutte 的公式推导,链接将放在文章末尾。

论文里面使用的是4阶 Runge-Kutta 方法:

%5Cbegin%7Balign*%7D%0A%26%20x_%7Bt%2B%5Cdelta%7D%3Dx_t%2B%5Cfrac%7B%5Cdelta%7D%7B6%7D(k_1%2B2k_2%2B2k_3%2Bk_4)%5C%5C%0A%26%20k_1%20%3D%20f(x_t%2Ct)%20%5C%5C%0A%26%20k_2%20%3D%20f(x_t%2B%5Cfrac%7B%5Cdelta%7D%7B2%7Dk_1%2Ct%2B%5Cfrac%7B%5Cdelta%7D%7B2%7D)%20%5C%5C%0A%26%20k_3%20%3D%20f(x_t%2B%5Cfrac%7B%5Cdelta%7D%7B2%7Dk_2%2C%20t%2B%5Cfrac%7B%5Cdelta%7D%7B2%7D)%20%5C%5C%0A%26%20k_4%20%3D%20f(x_t%2B%5Cdelta%20k_3%2C%20t%2B%5Cdelta)%0A%5Cend%7Balign*%7D

Linear Multistep 方法

Linear Multistep 和 Runge-Kutta 差不多,都是 ODE 的数值解法,其一般形式为:

%5Cbegin%7Balign*%7D%0A%26%20%5Calpha_s%20y_%7Bt%2Bs%7D%20%2B%20%5Calpha_%7Bs-1%7D%20y_%7Bt%2Bs-1%7D%20%2B%20...%20%2B%20%5Calpha_0%20y_0%3D%20%5C%5C%0A%26%20h%5Cbeta_s%20f_%7Bt%2Bs%7D%2Bh%20%5Cbeta_%7Bs-1%7D%20f_%7Bt%2Bs-1%7D%20%2B%20...%20%2B%20h%20%5Cbeta_0%20f_t%0A%5Cend%7Balign*%7D

论文里面采用的形式为:

x_%7Bt%2B%5Cdelta%7D%3Dx_t%2B%5Cfrac%7B%5Cdelta%7D%7B24%7D(55f_t-59f_%7Bt-%5Cdelta%7D%2B37f_%7Bt-2%5Cdelta%7D-9f_%7Bt-3%5Cdelta%7D)%2C%5C%20f_%7Bt%7D%3Df(x_t%2C%20t)

ODE 与数值解

之前在 DDIM 的笔记里面,最后一节里面也给出了当 %5Csigma%3D0 时,DDIM 的 ODE 形式,在这里,论文里面对于公式做了一些变形:

%5Cfrac%7Bdx%7D%7Bdt%7D%20%3D%20-%5Cbar%7B%5Calpha%7D'(t)%5Cleft(%5Cfrac%7Bx(t)%7D%7B2%5Cbar%7B%5Calpha%7D(t)%7D-%5Cfrac%7B%5Cepsilon_%5Ctheta(x(t)%2Ct)%7D%7B2%5Cbar%7B%5Calpha%7D(t)%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D(t)%7D%7D%5Cright)

至于公式是怎么来的,不是特别重要,故省略步骤了。

有了上述的 ODE 公式以后,可以套用一些数值解法进行求解。但是论文发现,数值解法下如果大幅减少步数(提速)会引入大量的噪音。

针对这种情况,研究团队做了调查,发现问题来源有两个。

首先,上述 ODE 公式和 %5Cepsilon_%5Ctheta 只在一小片弧形的区域里面明确定义(well-defined),在这块弧形区域之外,将没有足够的样本让 %5Cepsilon_%5Ctheta 去贴合噪音,而传统的数值法都是沿着直线计算的,所以会引入误差。

只有在红色的弧线区域里面,才明确定义。区域之外,都会有很大的误差。

其次,线性计划下的 %5Cbeta_t,当 t 趋近于 0 时,ODE 的值将会趋于无穷,于是违反了数值解法的先决条件。

解决方法

针对以上问题,论文里面提出了相应的解决方案。

针对第一点,论文提出,应该在一个流形上去求解(把弧线拉直,或者说把直线拉弯)。这里,我们将传统的数值解法(即 Runge-Kutta, Linear Multistep 等等)分为两部分——梯度部分(Gradient Part)f' 和线性的转移部分(Transfer Part)%5Ccolor%7Bpurple%7D%7Bx_%7Bt%2B%5Cdelta%7D%7D%3D%5Ccolor%7Bgreen%7D%7Bx_%7Bt%7D%7D%2B%5Cdelta%20%5Ccolor%7Bbrown%7D%7Bf'%7D. 如果转移部分不是线性的,那么这种数值解法论文里面称作伪数值方法(Pseudo Numerical Method)。

很明显,如果我们的目标分布是弧形的话,我们不能使用线性的转移方法。并且,如果我们的梯度部分如果是精确的(即,有足够的数据去衡量噪音 %5Cepsilon_%5Ctheta),那么我们的转移部分也应该尽量地贴合流形。

所以,在这里如果假设 x_t%3D%20%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7Dx_0%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_t%7D%5Cepsilon, 且我们的噪音 %5Cepsilon_%5Ctheta 是精确的,即 %5Cepsilon_%5Ctheta%3D%5Cepsilon, 那么根据 DDIM 的逆向过程:

 x_%7Bt-%5Cdelta%7D%20%3D%20%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D%5Cleft(%5Cfrac%7Bx_t-%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_t%7D%5Cepsilon_%5Ctheta(x_t%2C%20t)%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7D%5Cright)%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Csigma%5E2_t%7D%5Cepsilon_%5Ctheta(x_t%2C%20t)%20%2B%20%5Csigma_t%20%5Cepsilon_t

(注意这里 %5Csigma_t%3D0

可以推导得出,对于任意的 t'%5Cleq%20t, 有:

%20%20%20%5Cbegin%7Bsplit%7D%0A%20%20%20%20%20%20x_%7Bt'%7D%20%26%3D%20%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7D%5Cleft(%5Cfrac%7Bx_t-%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_t%7D%5Cepsilon_%5Ctheta%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7D%5Cright)%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7D%5Cepsilon_%5Ctheta%20%5C%5C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%26%3D%20%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7D%5Cleft(%5Cfrac%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7Dx_0%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_t%7D%5Cepsilon-%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_t%7D%5Cepsilon%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7D%5Cright)%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7D%5Cepsilon_%5Ctheta(x_t%2C%20t)%20%5C%5C%0A%20%20%20%20%20%20%20%20%20%20%20%20%20%26%3D%20%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7Dx_0%20%2B%20%5Csqrt%7B1-%5Cbar%7B%5Calpha%7D_%7Bt'%7D%7D%5Cepsilon.%0A%20%20%20%5Cend%7Bsplit%7D

所以,得出的结论就是,如果 %5Cepsilon 是精确的,那么 x_%7Bt'%7D 也是精确的。

如果将上面逆向过程的式子,两边同时减去 x_t, 那么就得到了如下的形式:

x_%7Bt-%5Cdelta%7D%20-%20%7Bx_t%7D%3D%20%5Cfrac%7B(%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Cbar%7B%5Calpha%7D_t)%20x_t%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D(%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D%2B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D)%7D%20-%20%0A%20%20%20%5Cfrac%7B(%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Cbar%7B%5Calpha%7D_t)%20%5Cepsilon_%5Ctheta(x_t%2C%20t)%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D(%5Csqrt%7B(1-%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D)%5Cbar%7B%5Calpha%7D_%7Bt%7D%7D%20%2B%20%5Csqrt%7B(1-%5Cbar%7B%5Calpha%7D_%7Bt%7D)%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D)%7D

然后我们对公式做一下变形,将 (%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Cbar%7B%5Calpha%7D_t) 拆分成 (%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D-%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D)(%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D%2B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D), 把等式左边的 x_t 移到等式右边,所以:

x_t%2B%5Cfrac%7B(%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Cbar%7B%5Calpha%7D_t)%20x_t%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D(%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D%2B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D)%7D%20%3D%5Cfrac%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7Dx_t%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7D%2B%5Cfrac%7B(%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D-%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D)%20x_t%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7D

进一步化简一下就得到了论文里面的形式:

x_%7Bt-%5Cdelta%7D%3D%20%5Cfrac%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D%7Dx_t%20-%20%0A%20%20%20%5Cfrac%7B(%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D-%5Cbar%7B%5Calpha%7D_t)%7D%7B%5Csqrt%7B%5Cbar%7B%5Calpha%7D_t%7D(%5Csqrt%7B(1-%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D)%5Cbar%7B%5Calpha%7D_%7Bt%7D%7D%20%2B%20%5Csqrt%7B(1-%5Cbar%7B%5Calpha%7D_%7Bt%7D)%5Cbar%7B%5Calpha%7D_%7Bt-%5Cdelta%7D%7D)%7D%5Cepsilon_t

所以,这就是我们转移部分的新公式,并且我们把这个公式命名为 %5Cphi(x_t%2C%20%5Cepsilon_t%2C%20t%2C%20t-%5Cdelta)。而 %5Cepsilon_t 则是梯度部分。

由于以上的公式是由 DDPM/DDIM 转换过来的,所以具备的一个特性是,当 t 趋向于 0 时,%5Cepsilon_%7B%5Ctheta%7D 将会越来越精确。于是顺带着解决了上面的第二个问题。

一张汇总表格,其中 PNDM 属于高阶非线性方法

梯度部分

如果只是上面的部分的话,其实没有什么特别的地方。因为整个公式其实是 DDIM 逆向过程的变形(说了那么多,原来都是白说了啊)。在这里,论文认为,梯度部分的计算可以借用传统的数值方法。

例如,在 linear multistep 方法下,我们有:

%5Cbegin%7Balign*%7D%0A%26e_t%20%3D%20%5Cepsilon_%5Ctheta(x_t%2C%20t)%5C%5C%0A%26e_t'%20%3D%20%5Cfrac%7B1%7D%7B24%7D(55e_t-59e_%7Bt-%5Cdelta%7D%2B37e_%7Bt-2%5Cdelta%7D-9e_%7Bt-3%5Cdelta%7D)%20%5C%5C%0A%26x_%7Bt%2B%5Cdelta%7D%20%3D%20%5Cphi(x_t%2C%20e_t'%2C%20t%2C%20t%2B%5Cdelta)%0A%5Cend%7Balign*%7D

论文里面把这种方法称作 PLMS(Pseudo Linear Multistep)。

由于 LMS 依赖于前三步,所以最初的三步需要用 Runge-Kutta 方法来算出来。Runge-Kutta 方法也有,称作 PRK,但是太长了不搬过来了,详见原论文。

基于 PLMS 的 F-PNDM 算法,F代表每一次计算要用到4步的数据

Second-order PNDM,每一步的计算利用的是2步的数据,但是原理差不多。只不过用的是2阶 Runge-Kutta 方法以及 2步形式的 Linear Multistep. 因为大同小异,就不搬过来了,内容见论文附录 A.3. 

这里是从官方的 Github上面 F-PNDM 摘录的上述算法的代码,以及一些个人加入的注释

一些推荐参考的资料

Runge-Kutte 的资料(英文),Numerical Approximations of Solutions to First-Order Equations -- Other Numerical Methods -- The Runge-Kutta Method 小节,里面还有例子

:https://www.sciencedirect.com/science/article/abs/pii/B9780128047767000024

Diffusers, 目前来说注释最完整,最容易理解的代码库:https://github.com/huggingface/diffusers/tree/main

官方的 PNDM Github, 里面还有 FON, 欧拉方法等其他采样方法:https://github.com/luping-liu/PNDM/tree/master

完。

AIGC: PNDM (Pseudo Numerical Method for Diffusion Models) 笔记的评论 (共 条)

分享到微博请遵守国家法律