Diffusion Model原理及代码解析
扩散模型(Diffusion Model)是一类生成模型,用于学习如何从噪声中生成数据。近年来,它们在图像生成、文本生成等领域取得了显著的进展。扩散模型的基本思想是通过逐步向数据添加噪声(通常是高斯噪声),让数据逐渐变得模糊,直至接近纯噪声的形式。然后,模型通过学习逆向过程,即从噪声逐步恢复出原始数据,从而生成新的样本。
目前有很多生成模型,比如GAN,VAE,Flow-based models,Diffusion models等等:
本文主要介绍的模型是Denoising Diffusion Probabilistic Models 。
一、基础知识
1.1 条件概率的一般形式
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B , A ) P ( B ∣ A ) P ( A ) P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ A , B ) \begin{align}
P(A,B,C)=P(C|B,A)P(B,A)=P(C|B,A)P(B|A)P(A) \tag{1-1} \newline
P(B,C|A)=P(B|A)P(C|A,B) \tag{1-2}\hspace{4.25cm}
\end{align}
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B , A ) P ( B ∣ A ) P ( A ) P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ A , B ) ( 1-1 ) ( 1-2 )
1.2 基于马尔科夫链假设的概率
马尔科夫链假设:简单来说就是研究对象在t t t 时刻的状态只与t − 1 t-1 t − 1 时刻的状态有关,而与之前的状态无关。
若满足马尔科夫链关系:A → B → C A \to B \to C A → B → C ,则有:
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B ) P ( B ∣ A ) P ( A ) P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ B ) \begin{align*}
P(A,B,C)=P(C|B,A)P(B,A)=P(C|B)P(B|A)P(A) \tag{1-3}\newline
P(B,C|A)=P(B|A)P(C|B) \hspace{4.25cm} \tag{1-4}
\end{align*}
P ( A , B , C ) = P ( C ∣ B , A ) P ( B , A ) = P ( C ∣ B ) P ( B ∣ A ) P ( A ) P ( B , C ∣ A ) = P ( B ∣ A ) P ( C ∣ B ) ( 1-3 ) ( 1-4 )
1.3 贝叶斯公式
P ( A ∣ B ) = P ( B ∣ A ) P ( A ) P ( B ) (1-5) P(A|B)=P(B|A)\frac{P(A)}{P(B)} \tag{1-5}
P ( A ∣ B ) = P ( B ∣ A ) P ( B ) P ( A ) ( 1-5 )
1.4 高斯分布的KL散度公式
相对熵(relative entropy),又被称为Kullback-Leibler散度(Kullback-Leibler divergence)或信息散度(information divergence),是两个概率分布(probability distribution)间差异的非对称性度量。在信息理论中,相对熵等价于两个概率分布的信息熵(Shannon entropy)的差值。
对于两个单一变量的高斯分布p p p 和q q q 而言,他们的KL散度为:
K L ( p , q ) = l o g σ 2 σ 1 + σ 1 2 + ( μ 1 − μ 2 ) 2 2 σ 2 2 − 1 2 (1-6) KL(p,q)=log\frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+(\mu_1-\mu_2)^2}{2\sigma_2^2}-\frac{1}{2} \tag{1-6}
K L ( p , q ) = l o g σ 1 σ 2 + 2 σ 2 2 σ 1 2 + ( μ 1 − μ 2 ) 2 − 2 1 ( 1-6 )
其中:
p ∼ N ( μ 1 , σ 1 2 ) q ∼ N ( μ 2 , σ 2 2 ) p \thicksim \mathcal{N}(\mu_1,\sigma_1^2) \newline
q \thicksim \mathcal{N}(\mu_2,\sigma_2^2)
p ∼ N ( μ 1 , σ 1 2 ) q ∼ N ( μ 2 , σ 2 2 )
1.5 重参数化技巧(Reparameterization Trick)
重参数化技巧(Reparameterization Trick) 是一种在深度学习中常用的技术,尤其是在变分自动编码器(Variational Autoencoder, VAE)中,用来解决对随机变量进行梯度求导的问题。
1.5.1 重参数化的核心思想
重参数化的核心思想是将一个随机变量的采样过程表示为一个确定性函数加上一些外部噪声。通过这种方式,我们能够将梯度传播到随机变量的参数上。
具体来说,假设我们有一个高斯分布z ∼ N ( μ , σ 2 ) z \sim \mathcal{N}(\mu, \sigma^2) z ∼ N ( μ , σ 2 ) ,传统的采样方法会直接从这个高斯分布中生成样本z z z ,但这样做时,z z z 的生成过程包含了随机性,不能直接对μ \mu μ 和σ \sigma σ 求梯度,因为采样操作不可导。
重参数化技巧的核心思想是: 将z z z 的采样过程重新表示为一个确定性函数形式:
z = μ + σ ⋅ ϵ (1-7) z = \mu + \sigma \cdot \epsilon \tag{1-7}
z = μ + σ ⋅ ϵ ( 1-7 )
其中,ϵ \epsilon ϵ 是从标准正态分布ϵ ∼ N ( 0 , 1 ) \epsilon \sim \mathcal{N}(0,1) ϵ ∼ N ( 0 , 1 ) 中采样的。通过这种变换,随机性被分离到了ϵ \epsilon ϵ 中,而z z z 现在是由μ \mu μ 和σ \sigma σ 确定性地生成的。这样,虽然我们在计算z z z 时仍然依赖噪声ϵ \epsilon ϵ ,但这个公式对μ \mu μ 和σ \sigma σ 是可导的,可以通过反向传播计算出损失函数对μ \mu μ 和σ \sigma σ 的梯度。
需要注意的是:重参数化只是将采样过程重写为可导的形式,而没有改变分布本身,z z z 依然服从于正态分布N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N ( μ , σ 2 ) 。
1.5.2 重参数化的优势
可导性:
通过将随机变量表示为确定性变量的函数,重参数化技巧使得我们能够对采样过程进行梯度求解,从而可以使用常规的反向传播算法来优化目标函数。
稳定性:
重参数化技巧将采样和参数分离,使模型在优化时更稳定。例如,在VAE中,通过这种技巧可以更有效地学习潜在空间的分布。
高效性:
通过重参数化,我们可以避免复杂的蒙特卡罗方法,减少计算量并加快模型的训练速度。
1.5.3 示例
假设我们要优化一个由高斯分布N ( μ , σ 2 ) \mathcal{N}(\mu, \sigma^2) N ( μ , σ 2 ) 生成的随机变量z z z ,我们希望通过最小化某个损失函数L ( z ) L(z) L ( z ) 来学习μ \mu μ 和σ \sigma σ 。直接对z z z 进行优化很难,但通过重参数化:
z = μ + σ ⋅ ϵ , ϵ ∼ N ( 0 , 1 ) z = \mu + \sigma \cdot \epsilon, \epsilon \sim \mathcal{N}(0,1)
z = μ + σ ⋅ ϵ , ϵ ∼ N ( 0 , 1 )
我们可以将损失函数L ( z ) L(z) L ( z ) 写成L ( μ + σ ⋅ ϵ ) L(\mu + \sigma \cdot \epsilon) L ( μ + σ ⋅ ϵ ) ,然后通过反向传播来计算关于μ \mu μ 和σ \sigma σ 的梯度。
二、Diffusion扩散过程
2.1 Diffusion正向扩散过程(Forward Process)
Diffsion model的正向传播过程是一个向图片添加噪音的过程。
给定初始数据分布x 0 ∼ q ( x ) x_0 \sim q(x) x 0 ∼ q ( x ) ,可以不断向分布中添加T T T 次高斯噪声,得到x 1 , x 2 , . . . , x T x_1,x_2,...,x_T x 1 , x 2 , ... , x T ,如下图的q q q 过程。这里需要给定一系列的高斯分布方差的超参数{ β t ∈ ( 0 , 1 ) } t = 1 T \{\beta_t \in (0,1)\}_{t=1}^T { β t ∈ ( 0 , 1 ) } t = 1 T 。假设每次添加噪声的过程符合马尔科夫链假设,则有:
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q ( x 1 : T ∣ x 0 ) = ∏ t = 1 T q ( x t ∣ x t − 1 ) \begin{align*}
q(x_t|x_{t-1})=\mathcal{N}(x_t;\sqrt{1-\beta_t}x_{t-1},\beta_t\Iota)\tag{2-1} \newline
q(x_{1:T}|x_0)=\prod_{t=1}^Tq(x_t|x_{t-1}) \tag{2-2}\hspace{1.55cm}
\end{align*}
q ( x t ∣ x t − 1 ) = N ( x t ; 1 − β t x t − 1 , β t I ) q ( x 1 : T ∣ x 0 ) = t = 1 ∏ T q ( x t ∣ x t − 1 ) ( 2-1 ) ( 2-2 )
其中,β t \beta_t β t 是随着t t t 的增大而逐渐变大的,在原论文中,β t ∈ ( 0.0001 , 0.02 ) \beta_t \in (0.0001,0.02) β t ∈ ( 0.0001 , 0.02 ) 。也就是说,在前向扩散过程中,随着t t t 的增大,β t \beta_t β t 不断增大,也就是噪声的占比逐渐变多,x t x_t x t 越来越接近纯噪声。当T → ∞ T \to \infin T → ∞ 时,x T x_T x T 是完全的高斯噪声。
利用重参数技巧我们可以将添加噪声的过程用下面的公式表示:
x t = α t x t − 1 + 1 − α t z t − 1 (2-3) x_t = \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_{t-1} \tag{2-3}
x t = α t x t − 1 + 1 − α t z t − 1 ( 2-3 )
其中:x t x_t x t 表示t t t 时刻的图像,x t − 1 x_{t-1} x t − 1 表示t − 1 t-1 t − 1 时刻的图像,z t − 1 z_{t-1} z t − 1 表示t − 1 t-1 t − 1 时刻添加的高斯噪声,其服从于高斯分布N ( 0 , 1 ) \mathcal{N}(0,1) N ( 0 , 1 ) 。α t = 1 − β t \alpha_t=1-\beta_t α t = 1 − β t 。
当我们不断带入前一时刻的表达式后可以得到:
x t = α t x t − 1 + 1 − α t z t − 1 = α t ( α t − 1 x t − 2 + 1 − α t − 1 z t − 2 ) + 1 − α t z t − 1 = α t α t − 1 x t − 2 + 1 − α t α t − 1 z t − 2 = . . . = α ˉ t x 0 + 1 − α ˉ t z t ˉ \begin{align*}
x_t &= \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z_{t-1} \\
&= \sqrt{\alpha_t}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_{t-2})+\sqrt{1-\alpha_t}z_{t-1} \\
&= \sqrt{\alpha_t\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_t\alpha_{t-1}}z_{t-2} \\
&= ... \\
&= \sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\bar{z_t} \tag{2-4}
\end{align*} x t = α t x t − 1 + 1 − α t z t − 1 = α t ( α t − 1 x t − 2 + 1 − α t − 1 z t − 2 ) + 1 − α t z t − 1 = α t α t − 1 x t − 2 + 1 − α t α t − 1 z t − 2 = ... = α ˉ t x 0 + 1 − α ˉ t z t ˉ ( 2-4 )
其中:α ˉ t = ∏ i = 1 T α i \bar \alpha_t=\prod_{i=1}^T\alpha_i α ˉ t = ∏ i = 1 T α i ,z t ˉ ∼ N ( 0 , 1 ) \bar{z_t} \sim \mathcal{N}(0,1) z t ˉ ∼ N ( 0 , 1 ) 。
注意:在第三行中我们使用了一个技巧:两个正态分布X ∼ N ( μ 1 , σ 1 2 ) X \sim \mathcal{N}(\mu_1,\sigma_1^2) X ∼ N ( μ 1 , σ 1 2 ) 和Y ∼ N ( μ 2 , σ 2 2 ) Y \sim \mathcal{N}(\mu_2,\sigma_2^2) Y ∼ N ( μ 2 , σ 2 2 ) 叠加后的分布a X + b Y aX+bY a X + bY 的均值为a μ 1 + b μ 2 a\mu_1+b\mu_2 a μ 1 + b μ 2 ,方差为a 2 σ 1 2 + b 2 σ 2 2 a^2\sigma_1^2+b^2\sigma_2^2 a 2 σ 1 2 + b 2 σ 2 2 ,所以α t ( 1 − α t − 1 ) z t − 2 + 1 − α t z t − 1 \sqrt{\alpha_t(1-\alpha_{t-1})}z_{t-2}+\sqrt{1-\alpha_t}z_{t-1} α t ( 1 − α t − 1 ) z t − 2 + 1 − α t z t − 1 可以参数重整化为只含一个随机变量z z z 构成的1 − α t α t − 1 z \sqrt{1-\alpha_t\alpha_{t-1}}z 1 − α t α t − 1 z 的形式。
根据上面的公式我们可以发现,想象中我们添加了T T T 次噪声,但实际上我们可以只通过对原图添加一次噪声就得到t t t 时刻的图像。
根据以上推导过程,我们可以得到任意时刻的x t x_t x t 都满足:
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) (2-5) q(x_t|x_0)=\mathcal{N}(x_t;\sqrt{\bar\alpha_t}x_0,(1-\bar\alpha_t)\Iota) \tag{2-5}
q ( x t ∣ x 0 ) = N ( x t ; α ˉ t x 0 , ( 1 − α ˉ t ) I ) ( 2-5 )
这为我们后面的预测模型打下了基础。
2.2 Diffusion逆向扩散过程(Reverse Process)
前面我们说到前向传播过程是一个向图片添加高斯噪声的过程,那么逆向传播过程就是为图片去噪的过程。
如果我们能够逐步得到逆转后的分布q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q ( x t − 1 ∣ x t ) ,就可以从完全的标准高斯分布x T ∼ N ( 0 , I ) x_T \sim \mathcal{N}(0,\Iota) x T ∼ N ( 0 , I ) 还原出原图分布x 0 x_0 x 0 .在已有的文献中已经证明了如果q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q ( x t ∣ x t − 1 ) 满足高斯分布且β t \beta_t β t 足够小,q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q ( x t − 1 ∣ x t ) 仍然是一个高斯分布。然而我们无法简单推断q ( x t − 1 ∣ x t ) q(x_{t-1}|x_t) q ( x t − 1 ∣ x t ) ,因此我们使用深度学习模型(参数为θ \theta θ ,目前主流是U-Net+attention的结构)去预测这样的一个逆向的分布p θ p_\theta p θ (类似VAE):
p θ ( x 0 : T ) = p ( x T ) ∏ t = 1 T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , ∑ θ ( x t , t ) ) \begin{align*}
p_\theta(x_{0:T}) = p(x_T)\prod_{t=1}^Tp_\theta(x_{t-1}|x_t) \tag{2-6}\newline
p_\theta(x_{t-1}|x_t) = \mathcal{N}(x_{t-1};\mu_\theta(x_t,t),\sum_\theta(x_t,t)) \tag{2-7}\hspace{-1.54cm}
\end{align*}
p θ ( x 0 : T ) = p ( x T ) t = 1 ∏ T p θ ( x t − 1 ∣ x t ) p θ ( x t − 1 ∣ x t ) = N ( x t − 1 ; μ θ ( x t , t ) , θ ∑ ( x t , t )) ( 2-6 ) ( 2-7 )
我们如何根据x t x_t x t 图像反向得到x t − 1 x_{t-1} x t − 1 呢?我们已经知道q ( x t ∣ x t − 1 ) q(x_t|x_{t-1}) q ( x t ∣ x t − 1 ) ,那么我们就可以根据贝叶斯公式得到:
q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t − 1 ) q ( x t ) (2-8) q(x_{t-1}|x_t)=q(x_t|x_{t-1})\frac{q(x_{t-1})}{q(x_t)} \tag{2-8}
q ( x t − 1 ∣ x t ) = q ( x t ∣ x t − 1 ) q ( x t ) q ( x t − 1 ) ( 2-8 )
但是在公式(2-8)中,q ( x t ) q(x_t) q ( x t ) 和q ( x t − 1 ) q(x_{t-1}) q ( x t − 1 ) 是未知的,由公式(2-4)我们知道可以由x 0 x_0 x 0 得到每一时刻的图像,所以我们可以得到:
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 z q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t z q ( x t ∣ x t − 1 , x 0 ) = α t x t − 1 + 1 − α t z \begin{align*}
q(x_{t-1}|x_0) = \sqrt{\bar\alpha_{t-1}}x_0+\sqrt{1-\bar\alpha_{t-1}}z \tag{2-9}\newline
q(x_t|x_0) = \sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}z \hspace{1.05cm} \tag{2-10}\newline
q(x_t|x_{t-1},x_0) = \sqrt{\alpha_t}x_{t-1}+\sqrt{1-\alpha_t}z \tag{2-11}
\end{align*}
q ( x t − 1 ∣ x 0 ) = α ˉ t − 1 x 0 + 1 − α ˉ t − 1 z q ( x t ∣ x 0 ) = α ˉ t x 0 + 1 − α ˉ t z q ( x t ∣ x t − 1 , x 0 ) = α t x t − 1 + 1 − α t z ( 2-9 ) ( 2-10 ) ( 2-11 )
那么公式(2-8)就可以化为:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) (2-12) q(x_{t-1}|x_t,x_0) = q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \tag{2-12}
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) ( 2-12 )
且:
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ t ~ ( x t , x 0 ) , β t ~ I ) (2-13) q(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1};\tilde{\mu_t}(x_t,x_0),\tilde{\beta_t}\Iota) \tag{2-13}
q ( x t − 1 ∣ x t , x 0 ) = N ( x t − 1 ; μ t ~ ( x t , x 0 ) , β t ~ I ) ( 2-13 )
根据高斯分布的概率密度函数:
f ( x ) = 1 2 π σ e − ( x − μ ) 2 2 σ 2 = 1 2 π σ e − 1 2 [ 1 σ 2 x 2 − 2 μ σ 2 x + μ 2 σ 2 ] \begin{align*}
f(x) &= \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \\
&= \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{1}{2}[\frac{1}{\sigma^2}x^2-\frac{2\mu}{\sigma^2}x+\frac{\mu^2}{\sigma^2}]} \tag{2-14}
\end{align*}
f ( x ) = 2 π σ 1 e − 2 σ 2 ( x − μ ) 2 = 2 π σ 1 e − 2 1 [ σ 2 1 x 2 − σ 2 2 μ x + σ 2 μ 2 ] ( 2-14 )
我们就将公式(2-12)化为:
q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t − 1 ∣ x 0 ) q ( x t ∣ x 0 ) ∝ e x p ( − 1 2 ( ( x t − α t x t − 1 ) 2 β t + ( x t − 1 − α ˉ t − 1 x 0 ) 2 1 − α ˉ t − 1 − ( x t − α ˉ t x 0 ) 2 1 − α ˉ t ) ) = e x p ( − 1 2 ( ( α t β t + 1 1 − α ˉ t − 1 ) x t − 1 2 − ( 2 α t β t x t + 2 α ˉ t − 1 1 − α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ) ) ) \begin{align*}
q(x_{t-1}|x_t,x_0) &= q(x_t|x_{t-1},x_0)\frac{q(x_{t-1}|x_0)}{q(x_t|x_0)} \\
&\propto exp(-\frac{1}{2}(\frac{(x_t-\sqrt{\alpha_t}x_{t-1})^2}{\beta_t}+\frac{(x_{t-1}-\sqrt{\bar\alpha_{t-1}x_0})^2}{1-\bar\alpha_{t-1}}-\frac{(x_t-\sqrt{\bar\alpha_t}x_0)^2}{1-\bar\alpha_t})) \\
&= exp(-\frac{1}{2}((\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}})x_{t-1}^2-(\frac{2\sqrt{\alpha_t}}{\beta_t}x_t+\frac{2\sqrt{\bar\alpha_{t-1}}}{1-\bar\alpha_{t-1}}x_0)x_{t-1}+C(x_t,x_0))) \tag{2-15}
\end{align*} q ( x t − 1 ∣ x t , x 0 ) = q ( x t ∣ x t − 1 , x 0 ) q ( x t ∣ x 0 ) q ( x t − 1 ∣ x 0 ) ∝ e x p ( − 2 1 ( β t ( x t − α t x t − 1 ) 2 + 1 − α ˉ t − 1 ( x t − 1 − α ˉ t − 1 x 0 ) 2 − 1 − α ˉ t ( x t − α ˉ t x 0 ) 2 )) = e x p ( − 2 1 (( β t α t + 1 − α ˉ t − 1 1 ) x t − 1 2 − ( β t 2 α t x t + 1 − α ˉ t − 1 2 α ˉ t − 1 x 0 ) x t − 1 + C ( x t , x 0 ))) ( 2-15 )
再根据公式(2-14)的形式我们可以得到公式(2-13)中的μ ~ ( x t , x 0 ) \tilde{\mu}(x_t,x_0) μ ~ ( x t , x 0 ) 和β t ~ \tilde{\beta_t} β t ~
β t ~ = 1 / ( α t β t + 1 1 − α ˉ t − 1 ) = 1 − α ˉ t − 1 1 − α ˉ t β t μ t ~ ( x t , x 0 ) = ( α t β t x t + α ˉ t 1 − α ˉ t x 0 ) / ( α t β t + 1 1 − α ˉ t − 1 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 β t 1 − α ˉ t x 0 \begin{align*}
\tilde{\beta_t} = 1/(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}})=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\beta_t \hspace{6.97cm}\tag{2-16}\newline
\tilde{\mu_t}(x_t,x_0) = (\frac{\sqrt{\alpha_t}}{\beta_t}x_t+\frac{\sqrt{\bar\alpha_t}}{1-\bar\alpha_t}x_0)/(\frac{\alpha_t}{\beta_t}+\frac{1}{1-\bar\alpha_{t-1}}) = \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar\alpha_t}x_t+\frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1-\bar\alpha_t}x_0 \tag{2-17}
\end{align*}
β t ~ = 1/ ( β t α t + 1 − α ˉ t − 1 1 ) = 1 − α ˉ t 1 − α ˉ t − 1 β t μ t ~ ( x t , x 0 ) = ( β t α t x t + 1 − α ˉ t α ˉ t x 0 ) / ( β t α t + 1 − α ˉ t − 1 1 ) = 1 − α ˉ t α t ( 1 − α ˉ t − 1 ) x t + 1 − α ˉ t α ˉ t − 1 β t x 0 ( 2-16 ) ( 2-17 )
再根据公式(2-4)我们可以得到:
x 0 = 1 α ˉ t ( x t − 1 − α ˉ t z t ˉ ) (2-18) x_0 = \frac{1}{\sqrt{\bar\alpha_t}}(x_t-\sqrt{1-\bar\alpha_t}\bar{z_t}) \tag{2-18}
x 0 = α ˉ t 1 ( x t − 1 − α ˉ t z t ˉ ) ( 2-18 )
将公式(2-18)带入到公式(2-17)中可以得到:
μ ~ t = 1 α t ( x t − 1 − α t 1 − α ˉ t z t ˉ ) (2-19) \tilde\mu_t = \frac{1}{\sqrt{\alpha_t}}(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\bar{z_t}) \tag{2-19}
μ ~ t = α t 1 ( x t − 1 − α ˉ t 1 − α t z t ˉ ) ( 2-19 )
但是在论文中实现的DDPM与我们上面的计算还是有一些小小的不同:
所以按照论文中的伪代码,DDPM的一次降噪过程实际上有如下几步:
将随机采样处的噪声图片x t x_t x t 以及t t t 传入Noise Predicter生成预测的噪声结果ϵ θ ( x t , t ) \epsilon_\theta(x_t,t) ϵ θ ( x t , t )
使用原图x t x_t x t 减去预测噪声的1 − α t 1 − α ˉ t \frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}} 1 − α ˉ t 1 − α t 倍后再乘以1 α t \frac{1}{\sqrt{\alpha_t}} α t 1 就可以得到比较干净的图了
但是在DDPM中,这里又多了一步,那就是将得到的图再加上了一个σ t \sigma_t σ t 倍的z z z ,其中z z z 也是从标准高斯分布中采样得到的噪声,然后得到了最终的图片。
对于这里为什么要加上σ t z \sigma_tz σ t z ,根据李宏毅老师说的就是为了引入随机性,只有引入随机性结果才会好,实验结果也证明了这点。
ChatGPT也是这样说的:
2.3 Training过程
Training过程如下:
在训练过程中,我们首先从分布q ( x 0 ) q(x_0) q ( x 0 ) 中采样出一个干净的图片x 0 x_0 x 0 ,然后从{ 1 , . . , T } \{1,..,T\} { 1 , .. , T } 中随机采样一个t t t 用于控制原图与添加的噪声的比例,最后再从标准正态分布中随机采样一个噪声ϵ \epsilon ϵ ,这样就完成了采样的工作。
然后我们通过x 0 x_0 x 0 和ϵ \epsilon ϵ 生成有噪声的图x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon x t = α ˉ t x 0 + 1 − α ˉ t ϵ ,之后我们将x t x_t x t 和t t t 输入到一个噪声预测器中,预测得到一个噪声ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) \epsilon_\theta(\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,t) ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ,然后将预测得到的噪声与我们的原噪音ϵ \epsilon ϵ 进行比较,具体过程如下图所示:
根据上面的分析,我们的目标就是尽可能减小原噪声ϵ \epsilon ϵ 与预测噪声ϵ θ \epsilon_\theta ϵ θ 的差别,所以我们的损失函数就是:
L s i m p l e ( θ ) = 𝔼 t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ] \mathcal{L}_{simple}(\theta) = \char"1D53C_{t,x_0,\epsilon}[||\epsilon-\epsilon_\theta(\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,t)||^2]
L s im pl e ( θ ) = E t , x 0 , ϵ [ ∣∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ]
三、代码实现
3.1 数据集选择
简单起见,我们使用sklearn库中的make_s_curve函数来生成三维S曲线数据集。
该make_s_curve函数使用数学公式生成S曲线,并返回包含生成的数据和相应目标值的元组。生成的数据是二维NumPy数组(n_samples, 3),其中n_samples为样本数,每一行代表三维空间中的一个点。目标值是一个一维NumPy形状数组(n_samples,),其中包含S曲线中每个点的颜色代码(介于 0 和 1 之间)。
使用make_s_curve函数生成S曲线的示例:
1 2 3 4 5 6 7 8 9 10 from sklearn.datasets import make_s_curveimport matplotlib.pyplot as plt X, y = make_s_curve(n_samples=1000 ) fig = plt.figure() ax = fig.add_subplot(projection='3d' ) ax.scatter(X[:, 0 ], X[:, 1 ], X[:, 2 ], c=y) plt.show()
生成的3D数据集可视化如下:
我们使用make_s_curve函数生成一个包含10000个点的数据集。为了方便起见。我们只取s_curve的第0维和第2维,相当于s_curve的一个截面。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 import matplotlib.pyplot as pltimport numpy as npfrom sklearn.datasets import make_s_curveimport torchs_curve, _ = make_s_curve(10 **4 , noise=0.1 ) s_curve = s_curve[:,[0 ,2 ]]/10.0 print ("shape of s:" , np.shape(s_curve))data = s_curve.T fig,ax = plt.subplots() ax.scatter(*data,color='orange' ,edgecolor='white' ); ax.axis('off' ) dataset = torch.Tensor(s_curve).float () plt.show()
最终生成的数据集如下所示:
3.2 计算超参数
在计算超参数之前,我们先来总结一下我们都有哪些超参数需要计算:
首先是β t \beta_t β t ,它控制了添加的噪声的比例。这里我们取β t ∈ [ 0.00001 , 0.005 ] \beta_t \in [0.00001, 0.005] β t ∈ [ 0.00001 , 0.005 ] 。
然后是α t \alpha_t α t 以及α ˉ \bar\alpha α ˉ ,其中α t = 1 − β t , α ˉ = ∏ i = 1 T α i \alpha_t=1-\beta_t,\bar\alpha=\prod_{i=1}^T\alpha_i α t = 1 − β t , α ˉ = ∏ i = 1 T α i 。
所以我们最后的代码就是:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 num_steps = 100 betas = torch.linspace(-6 , 6 , num_steps) betas = torch.sigmoid(betas)*(0.5e-2 - 1e-5 )+1e-5 alphas = 1 - betas alphas_prod = torch.cumprod(alphas,0 ) alphas_prod_p = torch.cat([torch.tensor([1 ]).float (), alphas_prod[:-1 ]], 0 ) alphas_bar_sqrt = torch.sqrt(alphas_prod) one_minus_alphas_bar_log = torch.log(1 - alphas_prod) one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) assert alphas.shape==alphas_prod.shape==alphas_prod_p.shape==alphas_bar_sqrt.shape==one_minus_alphas_bar_log.shape==one_minus_alphas_bar_sqrt.shapeprint ("all the same shape" ,betas.shape)
其中:
torch.linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False): 其作用是返回一个tensor张量,这个张量包含了从start到end的等距的steps个数据点;
torch.sigmoid():sigmoid公式:f ( x ) = 1 1 + e − x f(x)=\frac{1}{1+e^{-x}} f ( x ) = 1 + e − x 1 。因为sigmoid函数的值域为[0,1],所以sigmoid(betas)可以将β \beta β 的值平滑地压缩到[0,1]之间,再通过乘以(0.5e-2 - 1e-5) + 1e-5将β \beta β 的值放缩到[1e-5,0.5e-2]之间。
alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0):torch.tensor([1]).float():创建一个标量1(对应α_0时的累乘值,通常初始状态下α_0设为1);alphas_prod[:-1]:取alphas_prod除去最后一个元素的部分,即[α_1 * α_2 * … * α_t];torch.cat([…], 0):将标量1和去掉最后一个元素的alphas_prod拼接在一起。
假设alphas_prod是[α_1, α_1 * α_2, α_1 * α_2 * α_3, …, α_1 * α_2 * … * α_t]。那么alphas_prod_p就是[1, α_1, α_1 * α_2, α_1 * α_2 * α_3, …, α_1 * α_2 * … * α_{t-1}](即去掉最后一个元素,前面拼上一个1)。
3.3 前向传播过程
前向扩散过程中,根据公式:
x t = α ˉ t x 0 + 1 − α ˉ t z t ˉ (3-1) x_t = \sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\bar{z_t} \tag{3-1}
x t = α ˉ t x 0 + 1 − α ˉ t z t ˉ ( 3-1 )
我们可以知道可以基于x 0 x_0 x 0 得到任意时刻t t t 的值x t x_t x t 。
代码如下:
1 2 3 4 5 6 7 8 def q_x (x_0, t ): noise = torch.randn_like(x_0) alphas_t = alphas_bar_sqrt[t] alphas_1_m_t = one_minus_alphas_bar_sqrt[t] return (alphas_t * x_0 + alphas_1_m_t * noise)
注 :torch.randn_like():返回一个与输入张量大小相同的张量,其中填充了均值为0方差为1的正态分布的随机值。
3.4 演示原始数据分布加噪100步后的效果
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 num_shows = 20 fig,axs = plt.subplots(2 , 10 , figsize=(28 ,3 )) plt.rc('text' ,color='black' ) for i in range (num_shows): j = i // 10 k = i % 10 q_i = q_x(dataset, torch.tensor([i*num_steps//num_shows])) axs[j,k].scatter(q_i[:,0 ], q_i[:,1 ], color='green' , edgecolor='white' ) axs[j,k].set_axis_off() axs[j,k].set_title('$q(\mathbf{x}_{' +str (i*num_steps//num_shows)+'})$' ) plt.show()
效果如下:
通过图像我们可以看出经过不断加噪之后图像会变得越来越趋于纯噪声图片。
3.5 编写拟合逆扩散过程高斯分布的模型
简单起见,我们使用一个简单的多层感知机(MLP)来实现模型:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 import torchimport torch.nn as nnclass MLPDiffusion (nn.Module): def __init__ (self, n_steps, num_units=128 ): super (MLPDiffusion,self).__init__() self.linears = nn.ModuleList( [ nn.Linear(2 , num_units), nn.ReLU(), nn.Linear(num_units,num_units), nn.ReLU(), nn.Linear(num_units,num_units), nn.ReLU(), nn.Linear(num_units,2 ), ] ) self.step_embeddings = nn.ModuleList( [ nn.Embedding(n_steps,num_units), nn.Embedding(n_steps,num_units), nn.Embedding(n_steps,num_units), ] ) def forward (self, x, t ): for idx, embedding_layer in enumerate (self.step_embeddings): t_embedding = embedding_layer(t) x = self.linears[2 *idx](x) x += t_embedding x = self.linears[2 *idx+1 ](x) x = self.linears[-1 ](x) return x
这里我们构建了一个7层的MLP模型(虽然MLPDiffusion中step_embeddings也定义了3个时间嵌入层(nn.Embedding),但这些不属于标准的"网络层"如线性层或激活层)。它们用于引入时间步的信息,并在每次计算中与线性层的输出相加,而不是单独作为一层计算):
Linear(2, num_units):从2维输入映射到128维。
ReLU():非线性激活。
Linear(num_units, num_units):128维到128维的映射。
ReLU():非线性激活。
Linear(num_units, num_units):再一次128维到128维的映射。
ReLU():非线性激活。
Linear(num_units, 2):输出层,128维压缩回2维。
注 :embedding层可以引入时间步的信息,通过时间步嵌入,模型能根据当前的时间步 t 调整对数据的处理策略。例如:早期时间步时,模型可能只需要做少量修复,因为数据仍然接近原始状态;而晚期时间步时,模型需要更复杂的操作来逆转严重污染的数据。在扩散模型中,时间步嵌入让模型理解数据如何从无噪声逐渐变为噪声化的过程,并帮助模型逐步去噪。
3.6 编写损失函数
在训练过程中,我们通过输入加噪后的图片x t = α ˉ t x 0 + 1 − α ˉ t ϵ x_t=\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon x t = α ˉ t x 0 + 1 − α ˉ t ϵ 和t t t 预测出噪声ϵ θ \epsilon_\theta ϵ θ ,再通过预测得到的噪声与原噪声对比获得损失。
所以我们的损失函数就是:
L s i m p l e ( θ ) = 𝔼 t , x 0 , ϵ [ ∣ ∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ] \mathcal{L}_{simple}(\theta) = \char"1D53C_{t,x_0,\epsilon}[||\epsilon-\epsilon_\theta(\sqrt{\bar\alpha_t}x_0+\sqrt{1-\bar\alpha_t}\epsilon,t)||^2]
L s im pl e ( θ ) = E t , x 0 , ϵ [ ∣∣ ϵ − ϵ θ ( α ˉ t x 0 + 1 − α ˉ t ϵ , t ) ∣ ∣ 2 ]
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 def diffusion_loss_fn (model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps ): batch_size = x_0.shape[0 ] t = torch.randint(0 , n_steps, size=(batch_size//2 ,)) t = torch.cat([t, n_steps-1 -t], dim=0 ) t = t.unsqueeze(-1 ) a = alphas_bar_sqrt[t] aml = one_minus_alphas_bar_sqrt[t] e = torch.randn_like(x_0) x = x_0 * a+e * aml output = model(x, t.squeeze(-1 )) return (e - output).square().mean()
3.7 编写逆扩散采样函数(inference过程)
根据论文中的采样过程:
我们可以根据x t x_t x t 和t t t 计算得到x t − 1 x_{t-1} x t − 1 ,然后一步一步往前推,直到将图像还原到x 0 x_0 x 0 。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 def p_sample_loop (model, shape, n_steps, betas, one_minus_alphas_bar_sqrt ): cur_x = torch.randn(shape) x_seq = [cur_x] for i in reversed (range (n_steps)): cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt) x_seq.append(cur_x) return x_seq def p_sample (model, x, t, betas, one_minus_alphas_bar_sqrt ): t = torch.tensor([t]) coeff = betas[t] / one_minus_alphas_bar_sqrt[t] eps_theta = model(x, t) mean = (1 /(1 -betas[t]).sqrt())*(x-(coeff*eps_theta)) z = torch.randn_like(x) sigma_t = betas[t].sqrt() sample = mean + sigma_t * z return (sample)
3.8 训练模型
训练模型并打印loss及中间重构效果:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 print ('Training model...' )batch_size = 128 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True ) num_epoch = 4000 plt.rc('text' ,color='blue' ) model = MLPDiffusion(num_steps) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3 ) for t in range (num_epoch): for idx, batch_x in enumerate (dataloader): loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_steps) optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(),1. ) optimizer.step() if (t % 100 == 0 ): print (loss) x_seq = p_sample_loop(model, dataset.shape, num_steps, betas, one_minus_alphas_bar_sqrt) fig, axs = plt.subplots(1 , 10 , figsize=(28 ,3 )) for i in range (1 , 11 ): cur_x = x_seq[i*10 ].detach() axs[i-1 ].scatter(cur_x[:,0 ],cur_x[:,1 ],color='red' ,edgecolor='white' ); axs[i-1 ].set_axis_off(); axs[i-1 ].set_title('$q(\mathbf{x}_{' +str (i*10 )+'})$' )
效果如下(由于图像过多,这里只截取了epoch为0,1000,2000,3000,4000时的图像):
根据图象我们可以看到,随着epoch的不断增加,图像的去噪效果越来越好。
3.9 动态可视化
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 import iofrom PIL import Imageimgs = [] for i in range (100 ): plt.clf() q_i = q_x(dataset,torch.tensor([i])) plt.scatter(q_i[:,0 ],q_i[:,1 ],color='red' ,edgecolor='white' ,s=5 ); plt.axis('off' ); img_buf = io.BytesIO() plt.savefig(img_buf,format ='png' ) img = Image.open (img_buf) imgs.append(img) reverse = [] for i in range (100 ): plt.clf() cur_x = x_seq[i].detach() plt.scatter(cur_x[:,0 ],cur_x[:,1 ],color='red' ,edgecolor='white' ,s=5 ); plt.axis('off' ) img_buf = io.BytesIO() plt.savefig(img_buf,format ='png' ) img = Image.open (img_buf) reverse.append(img) imgs = imgs imgs[0 ].save("diffusion_qian.gif" , format ='GIF' , append_images=imgs, save_all=True , duration=100 , loop=0 ) imgs = reverse imgs[0 ].save("diffusion_ni.gif" , format ='GIF' , append_images=imgs, save_all=True , duration=100 , loop=0 )
前向过程:
逆向过程:
参考资料:
[1] Denoising Diffusion Probabilistic Models
[2] 扩散模型 - Diffusion Model【李宏毅2023】
[3] Probabilistic Diffusion Model概率扩散模型理论与完整PyTorch代码详细解读
[4] 由浅入深了解Diffusion Model
[5] 百度百科