Table of contents
- Source Video: Diffusion Models | Paper Explanation | Math Explained - Outlier
- Code: dome272/Diffusion-Models-pytorch
(2023-08-02)
Idea & Theory
A diffusion model is a generative model, so it learns the distribution of the data X. (A discriminative model learns labels. MLE is a strategy to determine the distribution through parameters θ.)
The essential idea is to systematically and slowly
destroy structure in a data distribution through
an iterative forward diffusion process. We then
learn a reverse diffusion process that restores
structure in data, yielding a highly flexible and
tractable generative model of the data. [1]
Forward diffusion process:
Sample noise from a normal distribution[1] and add it to an image iteratively, until the original distribution of the image has been completely destroyed, becoming the same as the noise distribution.
- The noise level (mean, variance) of each timestep is scaled by a schedule to avoid variance explosion as more noise is added.
- The image distribution should be destroyed slowly, and noise redundancy at the end stage should be reduced.
- OpenAI[3] proposed the cosine schedule in favor of the linear schedule in DDPM[2].
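The two schedules can be sketched in NumPy. This is a minimal sketch assuming the usual DDPM linear endpoints (1e-4 to 0.02) and the cosine formulation from the improved-DDPM paper; the clipping constant 0.999 follows that paper.

```python
import numpy as np

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linear schedule from DDPM: beta_t grows linearly with t."""
    return np.linspace(beta_start, beta_end, T)

def cosine_beta_schedule(T, s=0.008):
    """Cosine schedule: derive betas from
    alpha_bar(t) = f(t)/f(0), with f(t) = cos((t/T + s)/(1 + s) * pi/2)^2."""
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2
    alpha_bar = f / f[0]
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return np.clip(betas, 0, 0.999)  # avoid singularities near t = T

betas = cosine_beta_schedule(1000)
alpha_bar = np.cumprod(1 - betas)
# alpha_bar decays smoothly from ~1 toward 0: the image is destroyed slowly
```

The cosine schedule keeps ᾱ_t close to 1 for small t (less destruction early) and avoids the near-pure-noise redundancy of the last linear-schedule steps.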
Reverse diffusion process:
Predict the noise of each step.
- Do not predict the full image in one shot, because that is intractable and yields worse results[1].
- Predicting the mean of the noise distribution and predicting the noise in the image directly are equivalent; they are just parameterized differently[2].
- Predict the noise directly, so it can be subtracted from the image.
- The variance σ² of the normal distribution followed by the noise can be fixed[2], but optimizing it together with the mean improves the log-likelihood[3].
Architecture
- DDPM used a UNet-like model:
  - Multi-level downsampling + ResNet blocks → low-res feature maps → upsample to the original size.
  - Concatenate RHS feature maps with the LHS feature maps of the same resolution to supplement location information for the features at each pixel.
  - Attention blocks over some of the LHS and RHS feature maps fuse features further.
  - A time embedding is added to each level of feature maps so the network can "identify" the amount of noise to predict during a forward pass at a given timestep.
- OpenAI's second paper[4] improved the model by:
  - Increasing levels, reducing width;
  - More attention blocks, more heads;
  - BigGAN residual blocks for downsampling and upsampling;
  - Adaptive group normalization after each ResNet block: GroupNorm + affine transform (time embedding * GN + label embedding).
- Classifier guidance is a separate classifier that helps to generate images of a given category.
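The time embedding mentioned above is commonly implemented as a sinusoidal embedding, Transformer-style; a minimal sketch (the exact layout varies by implementation, this assumes the common sin/cos-concatenation variant):

```python
import numpy as np

def timestep_embedding(t, dim):
    """Sinusoidal timestep embedding: (sin, cos) features at
    geometrically spaced frequencies, analogous to Transformer positions."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    args = np.asarray(t, dtype=np.float64)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

# one 128-d vector per timestep; each level of the UNet adds a projection of it
emb = timestep_embedding(np.array([0, 10, 500]), 128)
```

Different timesteps get distinct, smoothly varying vectors, which is what lets one network handle every noise level.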
Math Derivation
(2024-04-17)
- VAE and diffusion models both follow the MLE strategy to find the parameters corresponding to the desired data distribution.
- VAE solves the dataset distribution P(X) by approximating the ELBO;
- while the diffusion model solves the dataset distribution P(X) by minimizing the KL divergence.
(2023-08-04)
VAE
- VAE also wants the distribution of the dataset P(X), where an x is generated from a latent variable z.
  Therefore, based on Bayes' theorem, p(x) = p(z) p(x|z) / p(z|x), where p(z) is the prior.
  And p(z|x) is intractable: in p(z|x) = p(z) p(x|z) / p(x), the evidence p(x) can't be computed through ∫ p(x,z) dz, since z is high-dimensional and continuous.
- By introducing an approximate posterior q(z|x), log p(x) = ELBO + KL divergence:
  $log\ p(x) = E_{q(z|x)} log \frac{p(x,z)}{q(z|x)} + ∫ q(z|x)\, log \frac{q(z|x)}{p(z)}\, dz$
  The KL divergence can be integrated analytically.
  The ELBO is an expectation w.r.t. q(z|x), which can technically be estimated by direct Monte Carlo sampling.
  But when a sampled q(zⁱ|x) is around 0, the variance of $log\ q(z|x)$ is high, making its gradient unstable and the optimization difficult. Moreover, the function being estimated, $log \frac{p_θ(x,z)}{q_φ(z|x)}$, involves two approximate models, each contributing error.
  Thus, it's not feasible to approximate the ELBO directly.
- To approximate the ELBO, we analyse the generative model (decoder) p(x|z).
  Based on Bayes' theorem, p(x|z) = p(x) p(z|x) / p(z).
  By introducing the posterior approximation q(z|x), p(x|z) yields E[log p(x|z)] = ELBO + KL divergence, i.e.,
  $E_{q(z|x)}[log\ p(x|z)] = E_{q(z|x)}\left[log \frac{p(z|x)\, p(x)}{q(z|x)}\right] + ∫ q(z|x)\, log \frac{q(z|x)}{p(z)}\, dz$
  Given z, the likelihood p(x|z) is supposed to be maximized. (The probability that the real x is sampled should be maximal.)
  Therefore, the parameters θ of the generative model p(x|z) should be optimized via an MLE (cross-entropy) loss.
- Now, since ELBO = E[log p(x|z)] - KL divergence and the KL divergence is known, the ELBO is obtained by just computing E[log p(x|z)].
  E[log p(x|z)] can be estimated by Monte Carlo: sample a z, compute log p(x|z), repeat N times, and take the average.
  The reconstruction from the sampled z should be close to the original x, so an MSE loss optimizes the parameters φ of z's distribution.
- (2023-10-30) z's distribution needs to be learned as well, for sampling x.
  But MC sampling is not differentiable, so φ cannot be optimized through gradient descent.
  Therefore, reparameterization treats z as a differentiable deterministic transform of a random noise ε, i.e., z = μ + σε.
  Then the parameters (μ, σ²) of z's distribution (encoder) are optimized by MSE.
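The reparameterization z = μ + σε can be sketched directly; a minimal NumPy sketch (the log-variance parameterization is the common convention, not something fixed by the text above):

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var, rng):
    """z = mu + sigma * eps: the randomness is pushed into eps ~ N(0, I),
    so gradients can flow through mu and log_var (the learnable parameters)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

mu = np.zeros(4)
log_var = np.log(np.full(4, 0.25))  # sigma = 0.5
# averaging many reparameterized samples recovers mu, confirming E[z] = mu
z_mean = np.mean([reparameterize(mu, log_var, rng) for _ in range(20000)], axis=0)
```

The key point is that `mu` and `log_var` appear only in differentiable arithmetic; the sampling itself touches only `eps`.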
Forward process
- The forward diffusion process is like the "encoder" q(z|x) in VAE:
  $$q(z|x) ≈ q(x_t | x_{t-1})$$
  The distribution of the image x_t at timestep t is determined by the image x_{t-1} at the previous timestep, where smaller t means less noise.
  Specifically, x_t follows a normal distribution with a mean of $\sqrt{1-β_t}\, x_{t-1}$ and a variance of $β_t I$:
  $$q(x_t | x_{t-1}) = N(x_t; \sqrt{1-β_t}\, x_{t-1}, β_t I)$$
  x_t is similar to x_{t-1} because its mean is centered around x_{t-1}.
- An image x is a "vector", and each element of it is a pixel.
- As the timestep t increases, β_t increases and (1-β_t) decreases, so the variance gets larger and the mean gets smaller.
  Intuitively, the value of the original pixel is fading, and more pixels become outliers, resulting in a wider range of variation around the mean.
- By introducing the notation $α_t = 1-β_t$, the t-step evolution from x_0 to x_t can be simplified into a single expression instead of sampling t times iteratively.
- Replacing (1-β_t) with α_t, the distribution becomes:
  $q(x_t | x_{t-1}) = N(x_t; \sqrt{α_t}\, x_{t-1}, (1-α_t) I)$
- Based on the reparameterization trick, a sample from this distribution is:
  $x_t = \sqrt{α_t}\, x_{t-1} + \sqrt{1-α_t}\, ε$
- Similarly, $x_{t-1} = \sqrt{α_{t-1}}\, x_{t-2} + \sqrt{1-α_{t-1}}\, ε$; plug it into x_t.
- Then $x_t = \sqrt{α_t} \left( \sqrt{α_{t-1}}\, x_{t-2} + \sqrt{1-α_{t-1}}\, ε \right) + \sqrt{1-α_t}\, ε$. Now the mean becomes $\sqrt{α_t α_{t-1}}\, x_{t-2}$.
- Since the two noise terms are independent Gaussians, their variances add: $α_t(1-α_{t-1}) + (1-α_t) = 1 - α_t α_{t-1}$.
  So the standard deviation is $\sqrt{1 - α_t α_{t-1}}$, and x_t becomes:
  $x_t = \sqrt{α_t α_{t-1}}\, x_{t-2} + \sqrt{1 - α_t α_{t-1}}\, ε$
- Repeatedly substituting intermediate states, x_t can be represented with x_0:
  $x_t = \sqrt{α_t α_{t-1} \cdots α_1}\, x_0 + \sqrt{1 - α_t α_{t-1} \cdots α_1}\, ε$
- Denoting the cumulative product α_t α_{t-1} ⋯ α_1 as $\bar α_t$, x_t can be reached in one shot:
  $x_t = \sqrt{\bar α_t}\, x_0 + \sqrt{1 - \bar α_t}\, ε$
- The distribution of x_t given x_0 is:
  $q(x_t | x_0) = N(x_t; \sqrt{\bar α_t}\, x_0, (1 - \bar α_t) I)$
- With this expression, the forward process is fixed and ready to use, and only the reverse process needs to be learned by a network.
- That's why, in the formula below, they "reverse" the forward q(x_t|x_{t-1}) into q(x_{t-1}|x_t), so the resulting equation only contains "reverse" transitions x_{t-1}|x_t, which can then be learned by narrowing the gap between q(x_{t-1}|x_t) and p(x_{t-1}|x_t).
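The one-shot variance (1 - ᾱ_t) can be sanity-checked against the step-by-step recursion numerically, a quick sketch (linear schedule assumed, any schedule works):

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)  # linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

# Iterate the one-step variance recursion:
#   Var[x_t] = alpha_t * Var[x_{t-1}] + (1 - alpha_t),
# starting from Var[x_0] = 0 (x_0 is a fixed, noise-free image).
var = np.zeros_like(betas)
v = 0.0
for t in range(len(betas)):
    v = alphas[t] * v + (1.0 - alphas[t])
    var[t] = v

# var now matches the closed form Var[x_t | x_0] = 1 - alpha_bar_t
```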
Reverse process
The reverse diffusion process is like the decoder in VAE:
$$p(x|z) ≈ p(x_{T-1} x_{T-2} \cdots x_0 | x_T)$$
- Given a noisy image x_t, the distribution of the less-noisy image x_{t-1} is
  $p(x_{t-1} | x_t) = N(x_{t-1}; μ_θ(x_t, t), Σ_θ(x_t, t))$
  where the variance can be fixed to the schedule β_t, so only the mean $μ_θ(x_t, t)$ needs to be learned with a network.
VLB
The VLB is the loss to be minimized. It gets simplified by:
- Applying Bayes' rule to "reverse" the direction of the forward process, which becomes "forward denoising" steps q(x_{t-1} | x_t), going from a noisy image to a less-noisy image;
- Adding extra conditioning on x_0 to each "forward denoising" step: q(x_{t-1} | x_t, x_0).
Derivation by step:
- The diffusion model wants parameters θ that maximize the likelihood of the original image x_0:
  $$θ = \arg\max_θ\ log\ p_θ(x_0)$$
  Adding a minus sign turns the objective into a minimization:
  -log p(x_0) = -ELBO - KL divergence
  $$\begin{aligned} -log\ p(x_0) & \left(= -log \frac{p(x_0, z)}{p(z|x_0)}\right) \\ &= -log \frac{p(x_{1:T}, x_0)}{p(x_{1:T} | x_0)} \\ &\quad \text{(introduce the "approximate posterior" } q \text{:)} \\ &= -\left(log \frac{p(x_{1:T}, x_0)}{q(x_{1:T} | x_0)} + log \frac{q(x_{1:T} | x_0)}{p(x_{1:T} | x_0)}\right) \end{aligned}$$
- Note that $q(x_{1:T} | x_0)$ is a joint distribution over x_1 … x_T that factorizes into T conditional distributions q(x_t | x_{t-1}).
- It is this step-by-step design that makes it possible to train a network to learn the data distribution. Meanwhile, the sampling process also has to be step-by-step.
  Compute the expectation w.r.t. $q(x_{1:T} | x_0)$ on both sides:
  $$E_{q(x_{1:T} | x_0)}[-log\ p(x_0)] = E_{q(x_{1:T} | x_0)}\left[-log \frac{p(x_{0:T})}{q(x_{1:T} | x_0)}\right] + E_{q(x_{1:T} | x_0)}\left[-log \frac{q(x_{1:T} | x_0)}{p(x_{1:T} | x_0)}\right]$$
  Expectation is equivalent to integration:
  $$\begin{aligned} \text{LHS: } & ∫_{x_{1:T}} q(x_{1:T} | x_0)\, (-log\ p(x_0))\, dx_{1:T} = -log\ p(x_0) \\ \text{RHS: } & E_{q(x_{1:T} | x_0)}\left[-log \frac{p(x_{0:T})}{q(x_{1:T} | x_0)}\right] + ∫_{x_{1:T}} q(x_{1:T} | x_0) \left(-log \frac{q(x_{1:T} | x_0)}{p(x_{1:T} | x_0)}\right) dx_{1:T} \end{aligned}$$
- Since the KL divergence is non-negative:
  -log p(x_0) ≤ -log p(x_0) + KL divergence =
  $$-log\ p(x_0) + D_{KL}(q(x_{1:T} | x_0)\,||\,p(x_{1:T} | x_0)) = -log\ p(x_0) + ∫_{x_{1:T}} q(x_{1:T} | x_0)\, log \frac{q(x_{1:T} | x_0)}{p(x_{1:T} | x_0)}\, dx_{1:T}$$
- Break apart the denominator $p(x_{1:T} | x_0)$ inside the KL divergence's logarithm using Bayes' rule:
  $$p(x_{1:T} | x_0) = \frac{p(x_{1:T}, x_0)}{p(x_0)} = \frac{p(x_{0:T})}{p(x_0)}$$
  Plug it back into the KL divergence:
  $$\begin{aligned} &∫_{x_{1:T}} q(x_{1:T} | x_0)\, log \frac{q(x_{1:T} | x_0)}{p(x_{1:T} | x_0)}\, dx_{1:T} \\ &= ∫_{x_{1:T}} q(x_{1:T} | x_0)\, log \frac{q(x_{1:T} | x_0)\, p(x_0)}{p(x_{0:T})}\, dx_{1:T} \\ &= ∫_{x_{1:T}} q(x_{1:T} | x_0) \left[ log\ p(x_0) + log \frac{q(x_{1:T} | x_0)}{p(x_{0:T})} \right] dx_{1:T} \\ &= log\ p(x_0) + ∫_{x_{1:T}} q(x_{1:T} | x_0)\, log \frac{q(x_{1:T} | x_0)}{p(x_{0:T})}\, dx_{1:T} \end{aligned}$$
- Plugging this decomposed KL divergence into the inequality above, the incomputable log-likelihood -log p(x_0) cancels, leaving the Variational Lower Bound (VLB):
  $$-log\ p(x_0) ≤ ∫_{x_{1:T}} q(x_{1:T} | x_0)\, log \frac{q(x_{1:T} | x_0)}{p(x_{0:T})}\, dx_{1:T}$$
  The argument of the log is a ratio of the forward process to the reverse process.
  The numerator is the distribution of $x_{1:T}$ given the starting point x_0. To give the numerator and denominator symmetric steps, the starting point of the reverse process, $p(x_T)$, can be separated out.
- Separate $p(x_T)$ out of the denominator by rewriting the joint probability as a cumulative product:
  $$p(x_{0:T}) = p(x_T)\, Π_{t=1}^T p(x_{t-1}|x_t)$$
  Plug it back into the logarithm of the VLB, and break the numerator's joint distribution into a product of steps as well:
  $$log \frac{q(x_{1:T} | x_0)}{p(x_T)\, Π_{t=1}^T p(x_{t-1}|x_t)} = log \frac{Π_{t=1}^T q(x_t|x_{t-1})}{p(x_T)\, Π_{t=1}^T p(x_{t-1}|x_t)} = Σ_{t=1}^T log \frac{q(x_t|x_{t-1})}{p(x_{t-1}|x_t)} - log\ p(x_T)$$
  This form includes every step rather than only the distribution of the whole trajectory $x_{1:T}$.
  (2023-08-11) The DM wants the data distribution, but it doesn't rebuild the transformation from Gaussian to data distribution directly; it approaches the corruption process step by step to reduce the difficulty (variance).
- Separate the first term (first step, t=1) from the summation, so that the remaining terms can be conditioned on x_0, reducing the variance:
  $$log \frac{q(x_1|x_0)}{p(x_0|x_1)} + Σ_{t=2}^T log \frac{q(x_t|x_{t-1})}{p(x_{t-1}|x_t)} - log\ p(x_T)$$
- Reformulate the numerator $q(x_t|x_{t-1})$ using Bayes' rule:
  $$q(x_t|x_{t-1}) = \frac{q(x_{t-1}|x_t)\, q(x_t)}{q(x_{t-1})}$$
  In this form, forward noising q and reverse denoising p point in the same direction, from x_t to x_{t-1}, so each step compares two distributions over the same transition.
- Condition each step on x_0 (allowed by the Markov property of the forward process) to reduce the variance (uncertainty):
  $$q(x_t|x_{t-1}) = q(x_t|x_{t-1}, x_0) = \frac{q(x_{t-1}|x_t, x_0)\, q(x_t|x_0)}{q(x_{t-1}|x_0)}$$
  And this distribution $q(x_{t-1}|x_t, x_0)$ has a closed-form solution.
  Here is why the first step is separated out: if t=1, conditioning $q(x_1|x_0)$ on x_0 gives
  $$q(x_1|x_0) = \frac{q(x_0|x_1, x_0)\, q(x_1|x_0)}{q(x_0|x_0)}$$
  There is a loop of $q(x_1|x_0)$ when x_0 appears on both sides, and the other terms $q(x_0|x_1, x_0)$ and $q(x_0|x_0)$ don't make sense.
  Plug the newly conditioned numerator back into the fraction, and break it apart with the log rule:
  $$Σ_{t=2}^T log \frac{q(x_t|x_{t-1})}{p(x_{t-1}|x_t)} = Σ_{t=2}^T log \frac{q(x_{t-1}|x_t, x_0)\, q(x_t|x_0)}{p(x_{t-1}|x_t)\, q(x_{t-1}|x_0)} = Σ_{t=2}^T log \frac{q(x_{t-1}|x_t, x_0)}{p(x_{t-1}|x_t)} + Σ_{t=2}^T log \frac{q(x_t|x_0)}{q(x_{t-1}|x_0)}$$
  The second term telescopes to $log \frac{q(x_T|x_0)}{q(x_1|x_0)}$.
  Then, the variational lower bound becomes:
  $$\begin{aligned} D_{KL}(q(x_{1:T}|x_0)\,||\,p(x_{0:T})) &= log \frac{q(x_1|x_0)}{p(x_0|x_1)} + Σ_{t=2}^T log \frac{q(x_{t-1}|x_t, x_0)}{p(x_{t-1}|x_t)} + log \frac{q(x_T|x_0)}{q(x_1|x_0)} - log\ p(x_T) \\ &= Σ_{t=2}^T log \frac{q(x_{t-1}|x_t, x_0)}{p(x_{t-1}|x_t)} + log \frac{q(x_T|x_0)}{p(x_0|x_1)} - log\ p(x_T) \\ &= Σ_{t=2}^T log \frac{q(x_{t-1}|x_t, x_0)}{p(x_{t-1}|x_t)} + log \frac{q(x_T|x_0)}{p(x_T)} - log\ p(x_0|x_1) \end{aligned}$$
  Taking the expectation over q turns each log-ratio into a KL divergence, so that a concrete expression can be determined later:
  $$\begin{aligned} & Σ_{t=2}^T D_{KL}(q(x_{t-1}|x_t, x_0)\,||\,p(x_{t-1}|x_t)) \\ & + D_{KL}(q(x_T|x_0)\,||\,p(x_T)) \\ & - log\ p(x_0|x_1) \end{aligned}$$
Loss function
The VLB to be minimized eventually reduces to an MSE loss between the actual noise and the predicted noise.
- $D_{KL}(q(x_T|x_0)\,||\,p(x_T))$ can be ignored:
  - $q(x_T|x_0)$ has no learnable parameters, because it just adds noise following a schedule.
  - And $p(x_T)$ is the noise image sampled from the normal distribution. Since $q(x_T|x_0)$ produces a final image that is supposed to follow the same normal distribution, this KL divergence should be small.
  Then the loss contains only the other two terms:
  $$L = Σ_{t=2}^T D_{KL}(q(x_{t-1}|x_t, x_0)\,||\,p(x_{t-1}|x_t)) - log\ p(x_0|x_1)$$
- $D_{KL}(q(x_{t-1}|x_t, x_0)\,||\,p(x_{t-1}|x_t))$ reduces to the MSE between the actual noise and the predicted noise.
- For the reverse pass, the distribution of the denoised image $p(x_{t-1}|x_t)$ has a parametric expression:
  $$p(x_{t-1}|x_t) = N(x_{t-1}; μ_θ(x_t,t), Σ_θ(x_t,t)) = N(x_{t-1}; μ_θ(x_t,t), β_t I)$$
  where Σ is fixed to β_t I, and only the mean $μ_θ(x_t,t)$ is learned, represented by a network output trained with the noise-MSE loss below.
- For the ("reversed") forward pass, the distribution of the noise-added image $q(x_{t-1}|x_t, x_0)$ has a closed-form solution, which can be written in a form similar to p(x_{t-1}|x_t):
  $$q(x_{t-1}|x_t, x_0) = N(x_{t-1}; \tilde μ_t(x_t, x_0), \tilde β_t I) \\ \tilde β_t = \frac{1-\bar α_{t-1}}{1-\bar α_t}\, β_t \\ \tilde μ_t(x_t, x_0) = \frac{\sqrt{α_t}\,(1-\bar α_{t-1})}{1-\bar α_t}\, x_t + \frac{\sqrt{\bar α_{t-1}}\, β_t}{1-\bar α_t}\, x_0$$
  (Note the coefficient of x_t uses $\sqrt{α_t}$, no bar, while the coefficient of x_0 uses $\sqrt{\bar α_{t-1}}$.)
  Since $\tilde β_t$ is fixed, only $\tilde μ_t(x_t, x_0)$ matters; it can be simplified with the one-step forward expression $x_t = \sqrt{\bar α_t}\, x_0 + \sqrt{1 - \bar α_t}\, ε$:
  $$x_0 = \frac{x_t - \sqrt{1 - \bar α_t}\, ε}{\sqrt{\bar α_t}}$$
  Plug x_0 into $\tilde μ_t(x_t, x_0)$; the mean of the noise-added image then no longer depends on x_0:
  $$\begin{aligned} \tilde μ_t(x_t, x_0) &= \frac{\sqrt{α_t}\,(1-\bar α_{t-1})}{1-\bar α_t}\, x_t + \frac{\sqrt{\bar α_{t-1}}\, β_t}{1-\bar α_t} \cdot \frac{x_t - \sqrt{1 - \bar α_t}\, ε}{\sqrt{\bar α_t}} \\ &= \frac{α_t(1-\bar α_{t-1}) + β_t}{\sqrt{α_t}\,(1-\bar α_t)}\, x_t - \frac{β_t}{\sqrt{α_t}\sqrt{1-\bar α_t}}\, ε \qquad \left(\text{using } \frac{\sqrt{\bar α_{t-1}}}{\sqrt{\bar α_t}} = \frac{1}{\sqrt{α_t}}\right) \\ &= \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε\right) \qquad \left(α_t(1-\bar α_{t-1}) + β_t = 1-\bar α_t\right) \end{aligned}$$
  So the mean of the distribution from which the less-noisy image at timestep t-1 is sampled is x_t minus some scaled random noise.
  x_t is known from the forward-process schedule, and $\tilde μ_t(x_t, x_0)$ is the target: the network optimizes its weights to make the predicted mean $μ_θ(x_t,t)$ match $\tilde μ_t(x_t, x_0)$.
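The algebraic simplification of the posterior mean can be verified numerically, a quick sketch with a toy schedule and a toy "image":

```python
import numpy as np

rng = np.random.default_rng(0)
T = 100
betas = np.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alpha_bar = np.cumprod(alphas)

t = 50                          # arbitrary interior timestep
x0 = rng.standard_normal(8)     # toy "image" of 8 pixels
eps = rng.standard_normal(8)
xt = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1 - alpha_bar[t]) * eps

# posterior mean written in terms of (x_t, x_0)
mu1 = (np.sqrt(alphas[t]) * (1 - alpha_bar[t - 1]) / (1 - alpha_bar[t]) * xt
       + np.sqrt(alpha_bar[t - 1]) * betas[t] / (1 - alpha_bar[t]) * x0)

# same mean after substituting x_0 = (x_t - sqrt(1 - alpha_bar_t) eps) / sqrt(alpha_bar_t)
mu2 = (xt - betas[t] / np.sqrt(1 - alpha_bar[t]) * eps) / np.sqrt(alphas[t])
# mu1 and mu2 agree to machine precision
```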
- Since the network only outputs μ, the KL divergence between two Gaussians with the same fixed variance simplifies into an MSE:
  $$L_t = \frac{1}{2σ_t^2} \| \tilde μ_t(x_t, x_0) - μ_θ(x_t,t) \|^2$$
  This MSE says the posterior mean of the forward process and the predicted mean of the reverse process should be as close as possible.
  Since the actual mean is $\tilde μ_t(x_t, x_0) = \frac{1}{\sqrt{α_t}} (x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε)$, and x_t is known (it is the network's input), the model is essentially estimating the actual random noise ε each time.
  Hence, the predicted mean $μ_θ(x_t,t)$ can be written in the same form as $\tilde μ_t(x_t, x_0)$, where only the noise $ε_θ$ carries parameters:
  $$μ_θ(x_t,t) = \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right)$$
  Therefore, the loss term becomes:
  $$L_t = \frac{1}{2σ_t^2} \| \tilde μ_t(x_t, x_0) - μ_θ(x_t,t) \|^2 = \frac{1}{2σ_t^2} \left\| \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε\right) - \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right) \right\|^2 = \frac{β_t^2}{2σ_t^2\, α_t (1-\bar α_t)} \|ε - ε_θ(x_t,t)\|^2$$
  Discarding the scaling factor brings better sample quality and an easier implementation, so the final loss for the KL divergence term is the MSE between actual and predicted noise at time t:
  $$\|ε - ε_θ(x_t,t)\|^2$$
- Once the mean $μ_θ(x_t,t)$ has been predicted from x_t and t, a "cleaner" image can be sampled from the distribution:
  $$N(x_{t-1}; μ_θ(x_t,t), Σ_θ(x_t,t)) = N\left(x_{t-1}; \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right), β_t I\right)$$
  Using the reparameterization trick, this sampled image is:
  $$x_{t-1} = μ_θ(x_t,t) + σ_t ε = \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right) + \sqrt{β_t}\, ε$$
- The last term in the VLB, $log\ p(x_0|x_1)$, is the predicted distribution of the original image x_0. Its quality is measured by the probability that the original image x_0 gets sampled from the estimated distribution $N(x; μ_θ^i(x_1,1), β_1)$.
  The probability of an image is a product over all D pixels, and the probability of each pixel is an integral of the PDF over an interval [δ₋, δ₊]:
  $$p_θ(x_0|x_1) = Π_{i=1}^D ∫_{δ_-(x_0^i)}^{δ_+(x_0^i)} N(x; μ_θ^i(x_1,1), β_1)\, dx$$
  - where $x_0^i$ is the pixel's ground truth;
  - $N(x; μ_θ^i(x_1,1), β_1)$ is the distribution to be integrated.
  This interval is determined by the actual pixel value:
  $$δ_+(x) = \begin{cases} ∞ & \text{if } x = 1 \\ x+\frac{1}{255} & \text{if } x < 1 \end{cases} \qquad δ_-(x) = \begin{cases} -∞ & \text{if } x = -1 \\ x-\frac{1}{255} & \text{if } x > -1 \end{cases}$$
- The original pixel range [0, 255] has been normalized to [-1, 1] to align with the standard normal distribution $p(x_T) \sim N(0, I)$.
- If the actual value is 1, the integration interval runs from the lower bound 1 - 1/255 ≈ 0.996 up to ∞.
  If the actual value is 0.5, the interval runs from 0.5 - 1/255 to 0.5 + 1/255, a width of 2/255.
(pic: area of the true pixel region under two predicted distributions)
If the area around the actual pixel value under the predicted PDF curve is large, the predicted distribution is good. However, if that area is small, the estimated mean is wrongly located.
Hence, this probability (log-likelihood) should be maximized; with the minus sign in front, it becomes the corresponding loss term.
However, the authors dropped this loss term $-log\ p(x_0|x_1)$ when training the network. The consequence is that at inference time, the final step from x_1 to x_0 adds no noise, because this step was never optimized. So, unlike the other sampling steps, the predicted x_0 has no random noise added:
$$\begin{aligned} \text{t>1:}\quad x_{t-1} &= \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right) + \sqrt{β_t}\, ε \\ \text{t=1:}\quad x_{t-1} &= \frac{1}{\sqrt{α_t}} \left(x_t - \frac{β_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right) \end{aligned}$$
A simpler intuition: we don't want to add noise to the final denoised output image x_0; otherwise the generated image is low-quality.
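The per-pixel bin integral described above can be sketched with the Gaussian CDF; a minimal sketch, where `pixel_log_likelihood` is an illustrative helper (not from the DDPM codebase):

```python
import math

def pixel_log_likelihood(x, mu, var):
    """Log-probability that ground-truth pixel x (in [-1, 1]) falls into its
    2/255-wide bin under N(mu, var); edge bins extend to +/- infinity."""
    sd = math.sqrt(var)
    cdf = lambda v: 0.5 * (1 + math.erf((v - mu) / (sd * math.sqrt(2))))
    upper = 1.0 if x >= 1 else cdf(x + 1 / 255)   # delta_plus
    lower = 0.0 if x <= -1 else cdf(x - 1 / 255)  # delta_minus
    return math.log(max(upper - lower, 1e-12))    # clamp to avoid log(0)

# a predicted mean near the true pixel yields a higher likelihood than a distant one
good = pixel_log_likelihood(0.5, mu=0.5, var=0.01)
bad = pixel_log_likelihood(0.5, mu=-0.5, var=0.01)
```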
The complete loss function is the MSE:
$$\begin{aligned} L_{simple} &= E_{t,x_0,ε}\left[ \|ε - ε_θ(x_t,t)\|^2 \right] \\ &= E_{t,x_0,ε}\left[ \|ε - ε_θ(\sqrt{\bar α_t}\, x_0 + \sqrt{1 - \bar α_t}\, ε,\ t)\|^2 \right] \end{aligned}$$
- t is sampled from a uniform distribution over {1, …, T};
- x_t comes from the one-step forward process.
Algorithms
DDPM paper
Training a model:
\begin{algorithm}
\caption{Training}
\begin{algorithmic}
\REPEAT
\STATE Sample $t \sim \mathrm{Uniform}(\{1, \dots, T\})$
\STATE Select an input image $x_0$ from the dataset
\STATE Sample noise $ε \sim N(0, I)$
\STATE Perform a gradient descent step on the loss: \\
$\|ε - ε_θ(\sqrt{\bar α_t}\, x_0 + \sqrt{1 - \bar α_t}\, ε,\ t)\|^2$
\UNTIL{converged}
\end{algorithmic}
\end{algorithm}
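One training iteration's loss computation can be sketched in NumPy (the UNet is stubbed out as a hypothetical `model`; a real implementation would backpropagate through it):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1 - betas)

def model(x_t, t):
    """Stand-in for the noise-predicting UNet eps_theta(x_t, t)."""
    return np.zeros_like(x_t)

def training_loss(x0, rng):
    t = rng.integers(1, T + 1)                 # t ~ Uniform({1, ..., T})
    eps = rng.standard_normal(x0.shape)        # eps ~ N(0, I)
    # one-shot forward process: x_t = sqrt(abar_t) x_0 + sqrt(1 - abar_t) eps
    x_t = np.sqrt(alpha_bar[t - 1]) * x0 + np.sqrt(1 - alpha_bar[t - 1]) * eps
    return np.mean((eps - model(x_t, t)) ** 2)  # ||eps - eps_theta||^2

loss = training_loss(rng.standard_normal((8, 8)), rng)
```

With the zero stub, the loss is just the mean squared noise; a trained network would drive it toward zero.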
Sampling from the learned data distribution by means of reparameterization trick:
\begin{algorithm}
\caption{Sampling}
\begin{algorithmic}
\STATE Sample a noise image $x_T \sim N(0, I)$
\FOR{$t = T, \dots, 1$}
\COMMENT{Remove noise step by step}
\IF{$t = 1$}
\STATE $ε = 0$
\ELSE
\STATE $ε \sim N(0, I)$
\ENDIF
\STATE $x_{t-1} = \frac{1}{\sqrt{α_t}} \left(x_t - \frac{1-α_t}{\sqrt{1 - \bar α_t}}\, ε_θ(x_t,t)\right) + σ_t ε$
\COMMENT{Reparameterization trick}
\ENDFOR
\RETURN $x_0$
\end{algorithmic}
\end{algorithm}
In this reparameterization formula, β_t is written as 1-α_t and $\sqrt{β_t}$ as σ_t (with σ_t² = β_t); these are notational variants of the earlier equation.
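The sampling loop can be sketched in NumPy (the trained noise predictor is stubbed out as a hypothetical `model`, so this only exercises the update rule, not real generation):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alpha_bar = np.cumprod(alphas)

def model(x_t, t):
    """Stand-in for the trained noise predictor eps_theta(x_t, t)."""
    return np.zeros_like(x_t)

x = rng.standard_normal((8, 8))              # x_T ~ N(0, I)
for t in range(T, 0, -1):                    # t = T, ..., 1
    eps = rng.standard_normal(x.shape) if t > 1 else 0.0  # no noise at t = 1
    coef = (1 - alphas[t - 1]) / np.sqrt(1 - alpha_bar[t - 1])
    x = (x - coef * model(x, t)) / np.sqrt(alphas[t - 1])
    x = x + np.sqrt(betas[t - 1]) * eps      # sigma_t = sqrt(beta_t)
# x is the generated sample x_0
```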
Training and Sampling share the common pipeline:
Improvements
Improvements from OpenAI’s 2021 papers.
- Learn a scale factor v for interpolating between the upper and lower bounds to get a flexible variance:
  $$Σ_θ(x_t,t) = exp(v\, log\ β_t + (1-v)\, log\ \tilde β_t)$$
  v is learned by adding an extra loss term $λ L_{VLB}$, with λ = 0.001:
  $$L_{hybrid} = E_{t,x_0,ε}\left[ \|ε - ε_θ(x_t,t)\|^2 \right] + λ L_{VLB}$$
- Use the cosine noise schedule $f(t) = cos\left(\frac{t/T+s}{1+s} \cdot \frac{π}{2}\right)^2$ in favor of the linear schedule.
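The variance interpolation between the two bounds can be sketched numerically; a minimal sketch assuming a linear β schedule (β̃_t is the posterior variance derived earlier):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1 - betas
alpha_bar = np.cumprod(alphas)
alpha_bar_prev = np.append(1.0, alpha_bar[:-1])
# posterior variance: beta_tilde_t = (1 - abar_{t-1}) / (1 - abar_t) * beta_t
beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * betas

def learned_variance(v, t):
    """Interpolate the log-variance between beta_tilde (lower bound, v=0)
    and beta (upper bound, v=1); v is predicted per pixel by the network."""
    return np.exp(v * np.log(betas[t]) + (1 - v) * np.log(beta_tilde[t]))

t = 500
lo, hi = learned_variance(0.0, t), learned_variance(1.0, t)
# v smoothly sweeps the variance between the two fixed choices from DDPM
```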