watch: Diffusion - Outlier | Explain 4 papers


(2023-08-02)

Idea & Theory

A diffusion model is a generative model, so it learns the distribution of the data 𝐗. (A discriminative model learns labels. And MLE is a strategy to determine the distribution through parameters 𝚯.)

The essential idea is to systematically and slowly
destroy structure in a data distribution through
an iterative forward diffusion process. We then
learn a reverse diffusion process that restores
structure in data, yielding a highly flexible and
tractable generative model of the data. [1]

Forward diffusion process:

Sample noise from a normal distribution¹ and add it to an image iteratively, until the original distribution of the image has been completely destroyed, becoming the same as the noise distribution.

  • The noise level (mean, variance) of each timestep is scaled by a schedule to avoid variance explosion as more noise is added.

  • The image distribution should be destroyed slowly, and the noise redundancy at the end stage should be reduced.

  • OpenAI³ proposed the cosine schedule in favor of the linear schedule in DDPM².
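The two schedules can be compared by the cumulative signal level $\bar αₜ$ they induce; a sketch using the papers' values (linear β from 1e-4 to 0.02, cosine offset s = 0.008):

```python
import math

T = 1000

# Linear schedule (DDPM): beta grows linearly from 1e-4 to 0.02.
betas_linear = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

# Cosine schedule (Improved DDPM): define alpha_bar via f(t) = cos(...)^2,
# then recover per-step betas from the ratio f(t)/f(t-1).
def f(t, s=0.008):
    return math.cos((t / T + s) / (1 + s) * math.pi / 2) ** 2

betas_cosine = [min(1 - f(t) / f(t - 1), 0.999) for t in range(1, T + 1)]

def alpha_bar(betas):
    """Cumulative product of alpha_t = 1 - beta_t."""
    out, p = [], 1.0
    for b in betas:
        p *= 1 - b
        out.append(p)
    return out

ab_lin, ab_cos = alpha_bar(betas_linear), alpha_bar(betas_cosine)

# The cosine schedule destroys the image distribution more slowly early on:
print(round(ab_lin[T // 4], 3), round(ab_cos[T // 4], 3))
```

At a quarter of the trajectory the cosine schedule retains noticeably more of the original signal, which is exactly the "destroy slowly" property argued for above.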

Reverse diffusion process:

Predict the noise of each step.

  • Do not predict the full image in one shot, because that's intractable and yields worse results¹.

  • Predicting the mean of the noise distribution and predicting the noise in the image directly are equivalent; they are just parameterized differently².

  • Predict the noise directly, so it can be subtracted from the image.

  • The variance σ² of the normal distribution followed by the noise can be fixed². But optimizing it together with the mean improves the log-likelihood³.


Architecture

  1. DDPM used a UNet-like model:

    • Multi-level downsampling + resnet blocks ➔ low-res feature maps ➔ upsampling to the original size
    • Concatenate RHS feature maps with the LHS feature maps of the same resolution to supplement the location information for features at each pixel.
    • Attend over some of the LHS and RHS feature maps with attention blocks to fuse features further.
    • A time embedding is added to each level of feature maps for “identifying” the consistent amount of noise to be predicted during a forward pass at a timestep.
  2. OpenAI's second paper⁴ improved the model by:

    1. Increasing depth (more levels) while reducing width;
    2. More attention blocks, with more heads;
    3. BigGAN residual blocks for downsampling and upsampling;
    4. Adaptive Group Normalization after each resnet block: an affine transform of the GroupNorm output whose scale and shift come from the time embedding and the label embedding;
    5. Classifier guidance: a separately trained classifier whose gradients steer generation toward a certain category.
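The AdaGN idea can be sketched numerically (illustrative shapes and projections, not the paper's code; the split of scale from the time embedding and shift from the label embedding follows the description above):

```python
import numpy as np

def group_norm(h, groups, eps=1e-5):
    """Normalize (C, H, W) features within channel groups."""
    c, hh, ww = h.shape
    g = h.reshape(groups, c // groups, hh, ww)
    mean = g.mean(axis=(1, 2, 3), keepdims=True)
    var = g.var(axis=(1, 2, 3), keepdims=True)
    return ((g - mean) / np.sqrt(var + eps)).reshape(c, hh, ww)

def ada_group_norm(h, t_emb, y_emb, W_scale, W_shift, groups=4):
    """AdaGN sketch: scale GroupNorm(h) with a projection of the time
    embedding and shift it with a projection of the label embedding."""
    scale = (W_scale @ t_emb)[:, None, None]   # per-channel scale from timestep
    shift = (W_shift @ y_emb)[:, None, None]   # per-channel shift from label
    return group_norm(h, groups) * (1 + scale) + shift

rng = np.random.default_rng(0)
C, E = 8, 16                                   # channels, embedding size (toy)
h = rng.standard_normal((C, 4, 4))
out = ada_group_norm(h, rng.standard_normal(E), rng.standard_normal(E),
                     rng.standard_normal((C, E)) * 0.1,
                     rng.standard_normal((C, E)) * 0.1)
print(out.shape)
```

This is how conditioning information (timestep, class label) is injected per feature level rather than only at the input.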

Math Derivation

(2024-04-17)

  • VAE and diffusion models both follow the MLE strategy to find the parameters corresponding to the desired data distribution.
  • VAE solves for the dataset distribution P(𝐗) by approximating the ELBO;
  • while the diffusion model solves for the dataset distribution P(𝐗) by minimizing the KL divergence.

(2023-08-04)

VAE

  1. VAE also wants to get the distribution of the dataset P(𝐗), where an 𝐱 is generated by a latent variable 𝐳.

    Therefore, based on Bayes' theorem, p(𝐱) = p(𝐳) p(𝐱|𝐳) / p(𝐳|𝐱), where p(𝐳) is the prior.

    And p(𝐳|𝐱) is intractable because in p(𝐳|𝐱) = p(𝐳)p(𝐱|𝐳) / p(𝐱), the p(𝐱) can't be computed through ∫f(𝐱,𝐳)d𝐳 since 𝐳 is high-dimensional and continuous.

  2. By introducing an approximate posterior q(𝐳|𝐱), log p(𝐱) = ELBO + KL-divergence.

    $log p(𝐱) = E_{q(𝐳|𝐱)} log \frac{p(𝐱,𝐳)}{q(𝐳|𝐱)} + ∫ q(𝐳|𝐱) log \frac{q(𝐳|𝐱)}{p(𝐳|𝐱)} d𝐳$

    The KL-divergence between two Gaussian distributions can be integrated analytically.

    The ELBO is an expectation w.r.t. q(𝐳|𝐱), which could technically be estimated using direct Monte Carlo sampling.

    But when a sampled q(𝐳ⁱ|𝐱) is close to 0, the variance of $log\ q(𝐳|𝐱)$ is high, making its gradient unstable and the optimization difficult. Moreover, the function being estimated, $log(\frac{p_θ(𝐱,𝐳)}{q_φ(𝐳|𝐱)})$, involves two approximate models that each contribute error.

    Thus, it's not feasible to approximate the ELBO directly.

  3. To approximate the ELBO, we analyse the generative model (decoder) p(𝐱|𝐳).

    Based on Bayes' theorem, p(𝐱|𝐳) = p(𝐱) p(𝐳|𝐱)/p(𝐳).

    By introducing the posterior approximation q(𝐳|𝐱), p(𝐱|𝐳) can be decomposed as E(log p(𝐱|𝐳)) = ELBO + KL-divergence, i.e.,

    $E_{q(𝐳|𝐱)}[ log p(𝐱|𝐳)] = E_{q(𝐳|𝐱)}[ log(\frac{p(𝐳|𝐱) p(𝐱)}{q(𝐳|𝐱)})] + ∫ q(𝐳|𝐱) log \frac{q(𝐳|𝐱)}{p(𝐳)} d𝐳$

    Given 𝐳, the likelihood p(𝐱|𝐳) is supposed to be maximized. (The probability that the real 𝐱 is sampled should be maximal.)

    Therefore, the parameters θ of the generative model p(𝐱|𝐳) should be optimized via an MLE (cross-entropy) loss.

  4. Now, since ELBO = E(log p(𝐱|𝐳)) - KL-divergence and the KL-div is known, the ELBO is obtained by just computing E(log p(𝐱|𝐳)).

    E(log p(𝐱|𝐳)) can be estimated by MC: sample a 𝐳, compute log p(𝐱|𝐳), repeat N times, and take the average.

    The reconstruction under the approximated E(log p(𝐱|𝐳)) should be close to the original 𝐱, so an MSE loss optimizes the parameters φ of the distribution of 𝐳.

    • (2023-10-30) 𝐳's distribution needs to be learned as well for sampling 𝐱.

    But MC sampling is not differentiable, so φ cannot be optimized through gradient descent.

    Therefore, reparameterization considers that 𝐳 comes from a differentiable deterministic transform of ε, a random noise, i.e., 𝐳 = μ + σε.

    Then, the parameters (μ, σ²) of 𝐳's distribution (encoder) are optimized by MSE.
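A minimal numerical sketch of the trick: all randomness lives in ε, while μ and σ remain deterministic, differentiable inputs (values below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

def reparameterize(mu, log_var, rng, n):
    """Draw n samples z ~ N(mu, sigma^2) as a deterministic transform of
    noise eps ~ N(0, 1); gradients could flow through mu and log_var."""
    eps = rng.standard_normal(n)          # all randomness lives here
    sigma = np.exp(0.5 * log_var)
    return mu + sigma * eps               # z = mu + sigma * eps

mu, log_var = 2.0, np.log(0.25)           # target: mean 2.0, std 0.5
z = reparameterize(mu, log_var, rng, 100_000)
print(z.mean().round(2), z.std().round(2))  # ≈ 2.0 and 0.5
```

The sample statistics match N(μ, σ²), but unlike calling a black-box sampler, the expression z = μ + σε exposes μ and σ to gradient descent.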

Forward process

  • The forward diffusion process is like the “Encoder” q(𝐳|𝐱) in VAE:

    $$q(𝐳|𝐱) ⇒ q(𝐱ₜ | 𝐱ₜ₋₁)$$

    The distribution of the image 𝐱ₜ at timestep t is determined by the image 𝐱ₜ₋₁ at the previous timestep, where a smaller t means less noise.

    Specifically, 𝐱ₜ follows a normal distribution with a mean of $\sqrt{1-βₜ}𝐱ₜ₋₁$ and a variance of $βₜ𝐈$:

    $$q(𝐱ₜ | 𝐱ₜ₋₁) = N(𝐱ₜ; \sqrt{1-βₜ} 𝐱ₜ₋₁, βₜ𝐈)$$

    𝐱ₜ is similar to 𝐱ₜ₋₁ because its mean is around 𝐱ₜ₋₁.

  • An image 𝐱 is a “vector”, and each element of it is a pixel.

  • As the timestep t increases, βₜ increases and (1-βₜ) decreases, which indicates the variance gets larger and the mean value gets smaller.

    Intuitively, the value of the original pixel xₜ₋₁ fades, and more pixels become outliers, resulting in a wider range of variation around the mean.

  • By introducing the notation $αₜ = 1-βₜ$, the t-step evolution from 𝐱₀ to 𝐱ₜ can be simplified to a single expression instead of sampling t times iteratively.

    • Replacing (1-βₜ) with αₜ, the distribution becomes:

      $q(𝐱ₜ | 𝐱ₜ₋₁) = N(𝐱ₜ; \sqrt{αₜ} 𝐱ₜ₋₁, (1-αₜ)𝐈)$

    • Based on the reparameterization trick, a sample from the distribution is:

      $𝐱ₜ = \sqrt{αₜ} 𝐱ₜ₋₁ + \sqrt{1-αₜ} ε$

    • Similarly, $𝐱ₜ₋₁ = \sqrt{αₜ₋₁} 𝐱ₜ₋₂ + \sqrt{1-αₜ₋₁} ε$; plug it into 𝐱ₜ.

    • Then $𝐱ₜ = \sqrt{αₜ} ( \sqrt{αₜ₋₁} 𝐱ₜ₋₂ + \sqrt{1-αₜ₋₁} ε ) + \sqrt{1-αₜ} ε$. Now, the mean becomes $\sqrt{αₜαₜ₋₁} 𝐱ₜ₋₂$.

    • The two noise terms are independent Gaussians, so their sum is Gaussian with variance equal to the sum $αₜ(1-αₜ₋₁) + (1-αₜ) = 1 - αₜαₜ₋₁$;

      the standard deviation is therefore $\sqrt{1 - αₜαₜ₋₁}$, and 𝐱ₜ becomes:

      $𝐱ₜ = \sqrt{αₜαₜ₋₁} 𝐱ₜ₋₂ + \sqrt{1 - αₜαₜ₋₁} ε$

    • By repeatedly substituting intermediate states, 𝐱ₜ can be represented with 𝐱₀:

      $𝐱ₜ = \sqrt{αₜαₜ₋₁ ... α₁} 𝐱₀ + \sqrt{1 - αₜαₜ₋₁ ... α₁} ε$

    • Denoting the cumulative product “αₜαₜ₋₁ … α₁” as $\bar αₜ$, 𝐱ₜ can be reached in one shot:

      $𝐱ₜ = \sqrt{\bar αₜ} 𝐱₀ + \sqrt{1 - \bar αₜ} ε$

    • The distribution of 𝐱ₜ given 𝐱₀ is:

      $q(𝐱ₜ | 𝐱₀) = N(𝐱ₜ; \sqrt{\bar αₜ} 𝐱₀, (1 - \bar αₜ)𝐈)$

With this expression, the deterministic forward process is ready-to-use and only the reverse process needs to be learned by a network.

  • That’s why in the formula below, they “reverse” the forward q(𝐱ₜ|𝐱ₜ₋₁) to q(𝐱ₜ₋₁|𝐱ₜ), resulting in an equation containing only “reverse-direction” terms 𝐱ₜ₋₁|𝐱ₜ, which can then be learned by narrowing the gap between q(𝐱ₜ₋₁|𝐱ₜ) and p(𝐱ₜ₋₁|𝐱ₜ).
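The one-shot expression can be checked against the iterative forward process numerically (toy scalar “images”, assumed linear schedule):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 200
betas = np.linspace(1e-4, 0.02, T)       # assumed linear schedule
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

x0 = np.full(100_000, 0.8)               # toy "images": a constant pixel

# Iterative forward process: T noising steps
x = x0.copy()
for t in range(T):
    x = np.sqrt(alphas[t]) * x + np.sqrt(betas[t]) * rng.standard_normal(x.shape)

# One-shot forward process via the cumulative product alpha_bar
x1 = np.sqrt(alpha_bar[-1]) * x0 + np.sqrt(1 - alpha_bar[-1]) * rng.standard_normal(x0.shape)

# Both should match N(sqrt(alpha_bar_T) * x0, 1 - alpha_bar_T)
print(x.mean().round(2), x1.mean().round(2))
print(x.var().round(2), x1.var().round(2))
```

The empirical mean and variance of the two constructions agree, confirming that the closed-form $q(𝐱ₜ|𝐱₀)$ replaces t iterative sampling steps.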

Reverse process

The reverse diffusion process is like the Decoder in VAE.

$$p(𝐱|𝐳) ⇒ p(𝐱ₜ₋₁, 𝐱ₜ₋₂, …, 𝐱₀ | 𝐱ₜ)$$
  • Given a noisy image 𝐱ₜ, the distribution of the less-noisy image 𝐱ₜ₋₁ is

    $p(𝐱ₜ₋₁ | 𝐱ₜ) = N(𝐱ₜ₋₁; μ_θ(𝐱ₜ, t), Σ_θ(𝐱ₜ, t))$

    where the variance can be fixed to the schedule βₜ, so only the mean $μ_θ(𝐱ₜ, t)$ needs to be learned with a network.

VLB

The VLB is the loss to be minimized. The VLB gets simplified by:

  1. Applying Bayes' rule to “reverse” the direction of the forward process, which becomes “forward denoising” steps q(𝐱ₜ₋₁ | 𝐱ₜ), because each goes from a noisy image to a less-noisy image;

  2. Adding extra conditioning on 𝐱₀ for each “forward denoising” step q(𝐱ₜ₋₁ | 𝐱ₜ, 𝐱₀).

Derivation step by step:

  • The diffusion model wants a set of parameters 𝛉 that maximizes the likelihood of the original image 𝐱₀.

    $$\rm θ = arg max_θ\ log\ p_θ(𝐱₀)$$

    By adding a minus sign, the objective turns into finding the minimum:

    -log p(𝐱₀) = -ELBO - KL-divergence

    $$ \begin{aligned} & -log p(𝐱₀) \left( = -log \frac{p(𝐱₀, 𝐳)}{p(𝐳|𝐱₀)} \right) \\\ &= -log \frac{p(𝐱_{1:T}, 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)} \\\ & \text{(Introduce the "approximate posterior" q:)} \\\ &= -(log \frac{ p(𝐱_{1:T}, 𝐱₀) }{ q(𝐱_{1:T} | 𝐱₀)} \ + log (\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)}) ) \\\ \end{aligned} $$
    • Note that $q(𝐱_{1:T} | 𝐱₀)$ represents a joint distribution that factorizes into T conditional distributions between 𝐱ₜ and 𝐱ₜ₋₁.

    • It is the step-by-step design that makes training a network to learn the data distribution possible. Meanwhile, the sampling process also has to be step-by-step.

    Compute the expectation w.r.t. $q(𝐱_{1:T} | 𝐱₀)$ on both sides.

    $$ E_{q(𝐱_{1:T} | 𝐱₀)} [ -log p(𝐱₀) ] \\\ \ = E_{q(𝐱_{1:T} | 𝐱₀)} \left[-log \frac{ p(𝐱_{0:T}) }{ q(𝐱_{1:T} | 𝐱₀)}\right] \ + E_{q(𝐱_{1:T} | 𝐱₀)} \left[-log (\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)})\right] $$

    Expectation is equivalent to integration.

    $$ \begin{aligned} & \text{LHS: } ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * (-log p(𝐱₀)) d𝐱_{1:T} = -log p(𝐱₀) \\\ & \text{RHS: } = E_{q(𝐱_{1:T} | 𝐱₀)} \left[-log \frac{ p(𝐱_{0:T}) }{ q(𝐱_{1:T} | 𝐱₀)}\right] \\\ & + ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * \left(-log (\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)})\right) d𝐱_{1:T} \end{aligned} $$
  • Since the KL-divergence is non-negative, we have:

    -log p(𝐱₀) ≤ -log p(𝐱₀) + KL-divergence =

    $$ \begin{aligned} & -log p(𝐱₀) + D_{KL}( q(𝐱_{1:T} | 𝐱₀) || p(𝐱_{1:T} | 𝐱₀) ) \\\ &= -log p(𝐱₀) \ + ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * \left(log (\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)})\right) d𝐱_{1:T} \end{aligned} $$
  • Break apart the denominator $p(𝐱_{1:T} | 𝐱₀)$ inside the KL-divergence's logarithm based on Bayes' rule:

    $$p(𝐱_{1:T} | 𝐱₀) = \frac{p(𝐱_{1:T}, 𝐱₀)}{p(𝐱₀)} = \frac{p(𝐱_{0:T})}{p(𝐱₀)}$$

    Plug it back into the KL-divergence:

    $$ \begin{aligned} &∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * \left( log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{1:T} | 𝐱₀)})\right) d𝐱_{1:T} \\\ &= ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * log (\frac{q(𝐱_{1:T} | 𝐱₀) p(𝐱₀)}{p(𝐱_{0:T})}) d𝐱_{1:T} \\\ &= ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * [ log p(𝐱₀) + log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{0:T})})] d𝐱_{1:T}\\\ &= ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * log p(𝐱₀) d𝐱_{1:T} \\\ &\quad + ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{0:T})}) d𝐱_{1:T} \\\ &= log p(𝐱₀) + ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀) * log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{0:T})}) d𝐱_{1:T} \end{aligned} $$
  • Plug this decomposed KL-divergence into the inequality above; the incomputable log-likelihood (-log p(𝐱₀)) cancels, resulting in the Variational Lower Bound (VLB):

    $$-log p(𝐱₀) ≤ ∫_{𝐱_{1:T}} q(𝐱_{1:T} | 𝐱₀)\ log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_{0:T})}) d𝐱_{1:T}$$

    The argument of the log is a ratio of the forward process to the reverse process.

    The numerator is the distribution of $𝐱_{1:T}$ given the starting point 𝐱₀. To make the numerator and denominator have symmetric steps, the starting point of the reverse process, $p(𝐱_T)$, can be separated out.

  • Separate out $p(𝐱_T)$ from the denominator by rewriting the joint probability as a cumulative product:

    $$ p(𝐱_{0:T}) = p(𝐱_T) Π_{t=1}^T p(𝐱ₜ₋₁|𝐱ₜ) $$

    Plug it back into the logarithm of the VLB, and break the joint distribution in the numerator into a product of T steps as well:

    $$ log(\frac{q(𝐱_{1:T} | 𝐱₀)}{p(𝐱_T) Π_{t=1}^T p(𝐱ₜ₋₁|𝐱ₜ)}) \ = log \frac{ Π_{t=1}^T q(𝐱ₜ|𝐱ₜ₋₁)}{ p(𝐱_T) Π_{t=1}^T p(𝐱ₜ₋₁|𝐱ₜ)} \\\ \ = log \frac{ Π_{t=1}^T q(𝐱ₜ|𝐱ₜ₋₁)}{ Π_{t=1}^T p(𝐱ₜ₋₁|𝐱ₜ)} - log p(𝐱_T) \\\ \ = ∑_{t=1}^T log (\frac{q(𝐱ₜ|𝐱ₜ₋₁)}{p(𝐱ₜ₋₁|𝐱ₜ)}) - log\ p(𝐱_T) $$

    This form includes every step rather than only the joint distribution of all of $𝐱_{1:T}$.

    (2023-08-11) The DM wants the data distribution, but it doesn't rebuild the transformation from the Gaussian to the data distribution directly; instead, it approaches the corruption process step by step to reduce the difficulty (variance).

  • Separate the first term (first step, t=1) from the summation, so that the other terms can be conditioned on 𝐱₀, thus reducing the variance:

    $$ log \frac{q(𝐱₁|𝐱₀)}{p(𝐱₀|𝐱₁)} + ∑_{t=2}^T log (\frac{q(𝐱ₜ|𝐱ₜ₋₁)}{p(𝐱ₜ₋₁|𝐱ₜ)}) - log\ p(𝐱_T) $$
  • Reformulate the numerator $q(𝐱ₜ|𝐱ₜ₋₁)$ based on Bayes' rule:

    $$ q(𝐱ₜ|𝐱ₜ₋₁) = \frac{q(𝐱ₜ₋₁|𝐱ₜ)q(𝐱ₜ)}{q(𝐱ₜ₋₁)} $$

    In this form, forward noising $q$ and reverse denoising $p$ describe the same transition from 𝐱ₜ to 𝐱ₜ₋₁, so in one pass the forward and reverse processes can be compared directly.

  • Make each step conditioned on 𝐱₀ to reduce the variance (uncertainty):

    $$ q(𝐱ₜ|𝐱ₜ₋₁) = \frac{q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)q(𝐱ₜ| 𝐱₀)}{q(𝐱ₜ₋₁| 𝐱₀)} $$

    And this distribution $q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)$ has a closed-form solution.

    Here is why the first step was separated out: if t=1, then $q(𝐱₁|𝐱₀)$ conditioned on 𝐱₀ would be:

    $$ q(𝐱₁|𝐱₀) = \frac{q(𝐱₀|𝐱₁, 𝐱₀)q(𝐱₁|𝐱₀)}{q(𝐱₀|𝐱₀)} $$

    $q(𝐱₁|𝐱₀)$ would appear on both sides, and the other terms $q(𝐱₀|𝐱₁, 𝐱₀)$ and $q(𝐱₀|𝐱₀)$ don't make sense.

    Plug the newly conditioned numerator back into the fraction, and break it apart using the log rule:

    $$ ∑_{t=2}^T log \frac{q(𝐱ₜ|𝐱ₜ₋₁)}{p(𝐱ₜ₋₁|𝐱ₜ)} \ = ∑_{t=2}^T log \frac{q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)q(𝐱ₜ| 𝐱₀)}{p(𝐱ₜ₋₁|𝐱ₜ)q(𝐱ₜ₋₁| 𝐱₀)} \\\ \ = ∑_{t=2}^T log \frac{q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)}{p(𝐱ₜ₋₁|𝐱ₜ)} + ∑_{t=2}^T log \frac{q(𝐱ₜ| 𝐱₀)}{q(𝐱ₜ₋₁| 𝐱₀)} \\\ $$

    The second sum telescopes to $log \frac{q(𝐱_T| 𝐱₀)}{q(𝐱₁| 𝐱₀)}$.

    Then, the variational lower bound becomes:

    $$ D_{KL}(q(𝐱_{1:T}|𝐱₀) || p(𝐱_{0:T})) = \\\ log \frac{q(𝐱₁|𝐱₀)}{p(𝐱₀|𝐱₁)} \ + ∑_{t=2}^T log \frac{q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)}{p(𝐱ₜ₋₁|𝐱ₜ)} \ + log \frac{q(𝐱_T| 𝐱₀)}{q(𝐱₁| 𝐱₀)} \ - log\ p(𝐱_T) \\\ \ \\\ \ = ∑_{t=2}^T log \frac{ q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) }{ p(𝐱ₜ₋₁|𝐱ₜ)} \ + log \frac{q(𝐱_T| 𝐱₀)}{p(𝐱₀|𝐱₁)} \ - log\ p(𝐱_T) \\\ \ \\\ \ = ∑_{t=2}^T log \frac{ q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) }{ p(𝐱ₜ₋₁|𝐱ₜ)} \ + log \frac{q(𝐱_T| 𝐱₀)}{p(𝐱_T)} \ - log p(𝐱₀|𝐱₁) $$

    Write this formula with KL-divergences, so that a concrete expression can be determined later. (Taking the expectation over q turns each log-ratio into a KL-divergence.) $$ \begin{aligned} & ∑_{t=2}^T D_{KL} (q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) || p(𝐱ₜ₋₁|𝐱ₜ)) \\\ & + D_{KL} (q(𝐱_T| 𝐱₀) || p(𝐱_T)) \\\ & - log\ p(𝐱₀|𝐱₁) \end{aligned} $$

Loss function

The VLB to be minimized is eventually derived into an MSE loss between the actual noise and the predicted noise.

  1. $D_{KL} (q(𝐱_T| 𝐱₀) || p(𝐱_T))$ can be ignored.

    • $q(𝐱_T| 𝐱₀)$ has no learnable parameters because it just adds noise following a schedule.
    • And $p(𝐱_T)$ is the noise image sampled from a normal distribution. Since $q(𝐱_T| 𝐱₀)$ is the distribution of the final image, which is supposed to be close to that normal distribution, this KL-divergence should be small.

    Then, the loss only contains the other two terms:

    $$L = ∑_{t=2}^T D_{KL} (q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) || p(𝐱ₜ₋₁|𝐱ₜ)) - log\ p(𝐱₀|𝐱₁)$$
  2. $D_{KL} (q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) || p(𝐱ₜ₋₁|𝐱ₜ))$ is the MSE between the actual noise and the predicted noise.

    • For the reverse pass, the distribution of the denoised image $p(𝐱ₜ₋₁|𝐱ₜ)$ has a parametric expression:

      $$p(𝐱ₜ₋₁|𝐱ₜ) = N(𝐱ₜ₋₁; μ_θ(𝐱ₜ,t), Σ_θ(𝐱ₜ,t)) \\\ = N(𝐱ₜ₋₁; μ_θ(𝐱ₜ,t), βₜ𝐈)$$

      where Σ is fixed as βₜ𝐈, and only the mean $μ_θ(𝐱ₜ,t)$ will be learned and represented by a network (output) through the noise MSE loss below.

    • For the (“reversed”) forward pass, the distribution of the noise-added image $q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀)$ has a closed-form solution (apply Bayes' rule to the three Gaussians $q(𝐱ₜ|𝐱ₜ₋₁)$, $q(𝐱ₜ₋₁|𝐱₀)$ and $q(𝐱ₜ|𝐱₀)$, then complete the square), which can be written in a similar form to p(𝐱ₜ₋₁|𝐱ₜ):

      $$ q(𝐱ₜ₋₁|𝐱ₜ, 𝐱₀) = N(𝐱ₜ₋₁; \tilde μₜ(𝐱ₜ,𝐱₀), \tilde βₜ𝐈) \\\ \ \\\ \tilde βₜ = \frac{1- \bar αₜ₋₁}{1-\bar αₜ} ⋅ βₜ \\\ \ \\\ \tilde μ(𝐱ₜ,𝐱₀) = \frac{\sqrt{αₜ} (1-\bar αₜ₋₁) }{1-\bar αₜ} 𝐱ₜ \ + \frac{\sqrt{\bar αₜ₋₁} βₜ}{1-\bar αₜ} 𝐱₀ $$

      Note that the coefficient of 𝐱ₜ uses $\sqrt{αₜ}$ while the coefficient of 𝐱₀ uses $\sqrt{\bar αₜ₋₁}$.

      where $\tilde βₜ$ is fixed, so only consider $\tilde μ(𝐱ₜ,𝐱₀)$, which can be simplified using the one-step forward process expression $𝐱ₜ = \sqrt{\bar αₜ} 𝐱₀ + \sqrt{1 - \bar αₜ} ε$, i.e.,

      $$ 𝐱₀ = \frac{𝐱ₜ - \sqrt{1 - \bar αₜ} ε}{\sqrt{\bar αₜ}} $$

      Plug 𝐱₀ into $\tilde μ(𝐱ₜ,𝐱₀)$; the mean then no longer depends on 𝐱₀:

      $$ \begin{aligned} \tilde μ(𝐱ₜ,𝐱₀) & = \frac{\sqrt{αₜ} (1-\bar αₜ₋₁) }{1-\bar αₜ} 𝐱ₜ \ + \frac{\sqrt{\bar αₜ₋₁} βₜ}{1-\bar αₜ} \ \frac{𝐱ₜ - \sqrt{1 - \bar αₜ} ε}{\sqrt{\bar αₜ}} \\\ \ \\\ & = \frac{αₜ(1-\bar αₜ₋₁) + βₜ}{\sqrt{αₜ}(1-\bar αₜ)} 𝐱ₜ \ - \frac{βₜ}{\sqrt{αₜ}\sqrt{1-\bar αₜ}} ε \quad \text{(and } αₜ(1-\bar αₜ₋₁) + βₜ = 1-\bar αₜ \text{)} \\\ & = \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε) \end{aligned} $$

      The mean of the distribution from which the less-noisy image 𝐱ₜ₋₁ at timestep t gets sampled is obtained by subtracting scaled random noise from the image 𝐱ₜ.

      𝐱ₜ is known from the forward process schedule, and $\tilde μ(𝐱ₜ,𝐱₀)$ is the target: the network optimizes its weights to make the predicted mean $μ_θ(𝐱ₜ,t)$ the same as $\tilde μ(𝐱ₜ,𝐱₀)$.

    • Since the network only outputs μ, the KL-divergence in the loss function can be simplified to an MSE:

      $$ Lₜ = \frac{1}{2σₜ²} \\| \tilde μ(𝐱ₜ,𝐱₀) - μ_θ(𝐱ₜ,t) \\|² $$

      This MSE indicates that the mean of the (reversed) forward step and the mean of the reverse denoising step should be as close as possible.

      Since the actual mean is $\tilde μ(𝐱ₜ,𝐱₀) = \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε)$, where 𝐱ₜ is known as it's the input to the network, the model is essentially estimating the actual $ε$ (random noise) every time.

      Hence, the predicted mean $μ_θ(𝐱ₜ,t)$ can be written in the same form as $\tilde μ(𝐱ₜ,𝐱₀)$, where only the noise $ε_θ$ has parameters:

      $$μ_θ(𝐱ₜ,t) = \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t))$$

      Therefore, the loss term becomes:

      $$ Lₜ = \frac{1}{2σₜ²} \\| \tilde μ(𝐱ₜ,𝐱₀) - μ_θ(𝐱ₜ,t) \\|² \\\ \ = \frac{1}{2σₜ²} \left\\| \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε) \ - \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)) \right\\|² \\\ \ = \frac{βₜ²}{2σₜ² αₜ (1-\bar αₜ)} \\|ε - ε_θ(𝐱ₜ,t) \\|² $$

      Disregarding the scaling factor brings better sampling quality and an easier implementation, so the final loss for the KL-divergence is the MSE between the actual noise and the predicted noise at time t:

      $$\\|ε - ε_θ(𝐱ₜ,t) \\|²$$
    • Once the mean $μ_θ(𝐱ₜ,t)$ has been predicted from 𝐱ₜ and t, a “cleaner” image can be sampled from the distribution:

      $$ N(𝐱ₜ₋₁; μ_θ(𝐱ₜ,t), Σ_θ(𝐱ₜ,t)) = N(𝐱ₜ₋₁; \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)), βₜ𝐈) $$
      Using the reparameterization trick, this sampled image is: $$ 𝐱ₜ₋₁ = μ_θ(𝐱ₜ,t) + σε \ = \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)) + \sqrt{βₜ}ε $$
  3. The last term, $log p(𝐱₀|𝐱₁)$, in the VLB is the predicted distribution of the original image 𝐱₀. Its goodness is measured by the probability that the original image $𝐱₀$ gets sampled from the estimated distribution $N(x; μ_θⁱ(𝐱₁,1), β₁)$.

    The probability of an image is a product over all D pixels, and the probability of a pixel is an integral over an interval [δ₋, δ₊] of the PDF curve:

    $$ p_θ(𝐱₀|𝐱₁) = ∏_{i=1}^D ∫_{δ₋(x₀ⁱ)}^{δ₊(x₀ⁱ)} N(x; μ_θⁱ(𝐱₁,1), β₁) dx $$
    • where $x₀ⁱ$ is the pixel's ground truth.
    • $N(x; μ_θⁱ(𝐱₁,1), β₁)$ is the distribution to be integrated.

    The interval is determined by the actual pixel value as:

    $$ δ₊(x) = \begin{cases} ∞ & \text{if } x = 1 \\\ x+\frac{1}{255} & \text{if } x < 1 \end{cases}, \quad δ₋(x) = \begin{cases} -∞ & \text{if } x = -1 \\\ x-\frac{1}{255} & \text{if } x > -1 \end{cases} $$
    • The original pixel range [0,255] has been normalized to [-1, 1] to align with the standard normal distribution $p(x_T) \sim N(0,1)$.

    • If the actual value is 1, the integral's upper bound is ∞ and its lower bound is 1-1/255 ≈ 0.996, so the interval runs from ≈0.996 to infinity.

      If the actual value is 0.5, the upper bound is 0.5+1/255 and the lower bound is 0.5-1/255, so the width of the interval is 2/255.

    (figure: area of the true-pixel region under two predicted distributions)

    If the area around the actual pixel value under the predicted distribution's PDF curve is large, the predicted distribution is good. However, if the area around the real pixel value is small, the estimated mean is wrongly located.

    Hence, this probability (log-likelihood) should be maximized, and considering the minus sign in front of it, the corresponding loss term follows.

    However, the authors dropped this loss term $-log p(𝐱₀|𝐱₁)$ when training the network. The consequence is that at inference time, the final step from 𝐱₁ to 𝐱₀ doesn't add noise, because this step wasn't optimized. The difference from the other sampling steps is that the predicted 𝐱₀ has no random noise added:

    $$ \begin{aligned} \text{t>1:}\quad 𝐱_{t-1} &= \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)) + \sqrt{βₜ} ε \\\ \text{t=1:}\quad 𝐱_{t-1} &= \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{βₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)) \end{aligned} $$

    A simple reason is that we don't want to add noise to the final denoised clean output image 𝐱₀; otherwise, the generated image would be low-quality.
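The per-pixel integral can be sketched with the Gaussian CDF; `mu` and `sigma` below are illustrative stand-ins for the model's predicted mean and fixed standard deviation:

```python
import math

def normal_cdf(x, mu, sigma):
    return 0.5 * (1 + math.erf((x - mu) / (sigma * math.sqrt(2))))

def pixel_log_likelihood(x0, mu, sigma):
    """Log of the probability mass that N(mu, sigma^2) assigns to the
    discretization bin around the true pixel value x0 in [-1, 1]."""
    upper = math.inf if x0 >= 1 else x0 + 1 / 255
    lower = -math.inf if x0 <= -1 else x0 - 1 / 255
    hi = 1.0 if upper == math.inf else normal_cdf(upper, mu, sigma)
    lo = 0.0 if lower == -math.inf else normal_cdf(lower, mu, sigma)
    return math.log(hi - lo)

# A mean located near the true pixel gives a higher log-likelihood
# than a badly located mean (values are illustrative):
good = pixel_log_likelihood(0.5, mu=0.5, sigma=1.0)
bad = pixel_log_likelihood(0.5, mu=-0.5, sigma=1.0)
print(good > bad)  # True
```

The edge branches implement the open-ended intervals for pixel values at the boundaries ±1 exactly as in the cases above.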

The complete loss function is MSE:

$$ \begin{aligned} \rm L_{simple} &= E_{t,𝐱₀,ε} [ || ε - ε_θ(𝐱ₜ,t)||² ] \\\ &= E_{t,𝐱₀,ε} [ || ε - ε_θ( \sqrt{\bar αₜ} 𝐱₀ + \sqrt{1 - \bar αₜ} ε, t) ||² ] \end{aligned} $$
  • t is sampled from a uniform distribution between 1 and T;
  • 𝐱ₜ is given by the one-step forward process.

Algorithms

DDPM paper

Training a model:

\begin{algorithm}
\caption{Training}
\begin{algorithmic}
\REPEAT
  \STATE Sample a t from U(1,T)
  \STATE Select an input image 𝐱₀ from the dataset
  \STATE Sample a noise ε from N(0,𝐈)
  \STATE Perform a gradient descent step on the loss: \\\\
  $||ε - ε_θ(\sqrt{\bar αₜ} 𝐱₀ + \sqrt{1 - \bar αₜ} ε, t)||²$
\UNTIL{converged}
\end{algorithmic}
\end{algorithm}
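The training algorithm can be sketched end-to-end on toy scalar data (the linear `eps_model` below is an illustrative stand-in for the U-Net ε_θ, and the schedule values are assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 100
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

# Stand-in for the U-Net: one linear weight per timestep,
# eps_theta(x_t, t) = w[t] * x_t  (purely illustrative).
w = np.zeros(T)
lr = 0.003

for step in range(3000):
    t = rng.integers(0, T, size=256)             # t ~ Uniform{1..T} (0-indexed)
    x0 = rng.standard_normal(256)                # a batch of "images"
    eps = rng.standard_normal(256)               # eps ~ N(0, I)
    x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1 - alpha_bar[t]) * eps
    residual = eps - w[t] * x_t                  # eps - eps_theta(x_t, t)
    np.add.at(w, t, lr * 2 * residual * x_t)     # SGD step on ||residual||^2

# For standard-normal toy data, the optimal predictor is
# w[t] = sqrt(1 - alpha_bar[t]); the learned weight should approach it.
print(w[T - 1], np.sqrt(1 - alpha_bar[T - 1]))
```

Note how the loop mirrors the algorithm exactly: sample t, sample data, sample noise, build 𝐱ₜ in one shot, then descend on the noise-prediction MSE.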

Sampling from the learned data distribution by means of the reparameterization trick:

\begin{algorithm}
\caption{Sampling}
\begin{algorithmic}
\STATE Sample a noise image $𝐱_T \sim N(0,𝐈)$
\FOR{t = T:1}
  \COMMENT{Remove noise step-by-step}
  \IF{t=1}
    \STATE ε = 0
  \ELSE
    \STATE ε ~ N(0,𝐈)
  \ENDIF

  \STATE $𝐱ₜ₋₁ = \frac{1}{\sqrt{αₜ}} (𝐱ₜ - \frac{1-αₜ}{\sqrt{1 - \bar αₜ}} ε_θ(𝐱ₜ,t)) + σₜε$
  \COMMENT{Reparam trick}

\ENDFOR
\RETURN 𝐱₀
\end{algorithmic}
\end{algorithm}

In this reparameterization formula, βₜ and $\sqrt{βₜ}$ are written as 1-αₜ and σₜ, which differs only in notation from the equation above.
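The sampling loop as a sketch on the same toy setup (`eps_model` is a stand-in for a trained network; for standard-normal toy data the ideal noise predictor happens to be known in closed form):

```python
import numpy as np

rng = np.random.default_rng(0)
T = 100
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)

def eps_model(x_t, t):
    # Stand-in for the trained network: for standard-normal "data",
    # the ideal noise predictor is sqrt(1 - alpha_bar_t) * x_t.
    return np.sqrt(1 - alpha_bar[t]) * x_t

x = rng.standard_normal(50_000)              # x_T ~ N(0, I)
for t in range(T - 1, -1, -1):               # remove noise step by step
    eps = rng.standard_normal(x.shape) if t > 0 else 0.0   # no noise at last step
    x = (x - (1 - alphas[t]) / np.sqrt(1 - alpha_bar[t]) * eps_model(x, t)) \
        / np.sqrt(alphas[t]) + np.sqrt(betas[t]) * eps     # sigma_t = sqrt(beta_t)

# The samples should land back on the data distribution, here N(0, 1)
print(x.mean().round(2), x.var().round(2))
```

Running the chain with an ideal predictor recovers the data distribution's statistics, illustrating that sampling really is the learned reverse of the forward corruption.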

Training and Sampling share the common pipeline:

(diagram: 𝐱ₜ and t → ε_θ → μ_θ(𝐱ₜ,t) → 𝐱ₜ₋₁, with training loss ‖ε - ε_θ‖²)

Improvements

Improvements from OpenAI’s 2021 papers.

  1. Learn a scale factor v for interpolating between the upper and lower bounds to get a flexible variance:

    $$Σ_θ(𝐱ₜ,t) = exp(v\ log βₜ + (1-v)\ log \tilde{β}ₜ)$$

    v is learned by adding an extra loss term $λ L_{VLB}$, with λ=0.001.

    $$L_{hybrid} = E_{t,𝐱₀,ε} [ || ε - ε_θ(𝐱ₜ,t)||² ] + λ L_{VLB}$$
  2. Use the cosine noise schedule $f(t)=cos(\frac{t/T+s}{1+s}⋅\frac{π}{2})²$ in favor of the linear schedule.
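Improvement 1 can be sketched as follows (`v` would be a per-pixel network output in [0,1]; here it is a fixed illustrative value):

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)
alpha_bar_prev = np.concatenate([[1.0], alpha_bar[:-1]])

# beta_tilde: the posterior variance, the lower bound on the reverse variance
beta_tilde = (1 - alpha_bar_prev) / (1 - alpha_bar) * betas

def sigma_squared(t, v):
    """Interpolate between the two variance bounds in log space."""
    return np.exp(v * np.log(betas[t]) + (1 - v) * np.log(beta_tilde[t]))

t = 500
# Any v in (0, 1) lands strictly between the lower and upper bounds:
print(beta_tilde[t] < sigma_squared(t, 0.5) < betas[t])  # True
```

The log-space interpolation keeps the learned variance positive and pinned between β̃ₜ and βₜ for any v in [0,1].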


Reference

  1. Sohl-Dickstein et al., “Deep Unsupervised Learning using Nonequilibrium Thermodynamics”, 2015.
  2. Ho et al., “Denoising Diffusion Probabilistic Models” (DDPM), 2020.
  3. Nichol & Dhariwal, “Improved Denoising Diffusion Probabilistic Models”, 2021.
  4. Dhariwal & Nichol, “Diffusion Models Beat GANs on Image Synthesis”, 2021.
