read: Gen - ELBO | VAE

Auto-Encoding Variational Bayes

Arxiv | Code-keras | Code-cvae

Recognition model with parameters φ, $q_φ(𝐳|𝐱)$, approximates the intractable posterior distribution;

Generative model with parameters θ, $p_θ(𝐳)p_θ(𝐱|𝐳)$, maps a latent variable to a “sample”.

Abstract

  • directed probabilistic model

    The distribution of an i.i.d. dataset with latent variables P(𝐗,𝐙)

  • continuous latent variables with intractable posterior distributions

    The latent variable 𝐳 of each datapoint is continuous, and its posterior distribution $p_θ(𝐳|𝐱)$ cannot be computed in closed form.

  • reparameterization of the variational lower bound yields a lower bound estimator

  • The lower bound estimator can be optimized using standard stochastic gradient methods (e.g., SGD).

Introduction

  • Variational Bayesian approach involves the optimization of an approximation to the intractable posterior.

    Use q(𝐳) to approximate $p_θ(𝐳|𝐱)$

  • Mean-field approach requires analytical solutions of expectations w.r.t. the approximate posterior

    𝔼_{q(𝐳)} [ · ]

  • These required expectations are themselves intractable in the general case.

Method

Problem scenario

  • Each sample is generated in two steps:

    1. sample a 𝐳 from the prior distribution $p_{θ^*}(𝐳)$
    2. sample an 𝐱 from the conditional distribution $p_{θ^*}(𝐱|𝐳)$

    where θ* denotes the true parameters.

  • Intractability:

    The integral of the marginal likelihood $p_θ(𝐱) = ∫ p_θ(𝐳) p_θ(𝐱|𝐳) d𝐳 = ∫ p_θ(𝐱,𝐳) d𝐳$ cannot be computed, because 𝐳 is a high-dimensional continuous variable.

    Hence the true posterior density $p_θ(𝐳|𝐱) = \frac{ p_θ(𝐱|𝐳) p_θ(𝐳) }{ p_θ(𝐱) }$ is also intractable, because the denominator (the marginal likelihood) cannot be computed.

    That means the EM algorithm cannot be used: in each iteration, the introduced distribution q(𝐳) needs to be set equal to $p_θ(𝐳|𝐱)$, which is intractable.

    Likewise, any reasonable mean-field variational Bayesian algorithm is intractable: there, q(𝐳) is fit to approximate $p_θ(𝐳|𝐱)$, which does not work when $p_θ(𝐳|𝐱)$ cannot be computed.

  • A large dataset:

    Sampling-based methods, e.g. Monte Carlo EM, would be too slow, because they involve an expensive sampling loop per datapoint.

Three meaningful problems:

  1. Estimating the parameters θ of the distribution via MLE or MAP allows mimicking the data-generating process: sample 𝐳 first, then sample 𝐱 from the conditional distribution.

    • MLE tries to find the θ that maximizes the probability of the dataset X under the model. MLE estimates the parameter θ, while the VAE aims to estimate the distribution of X.

    • MAP holds that the θ with the maximum posterior probability, given the dataset likelihood p(X|θ) and the prior probability p(θ), according to $p(θ|X) = \frac{p(X|θ) p(θ)}{p(X)}$, is the most plausible.

    • The EM algorithm uses MLE to find the parameter θ of a probabilistic model involving latent variables: in the E-step, set q(𝐳) equal to the posterior probability $p_θ(𝐳|𝐱)$ of the latent variable; in the M-step, apply MLE to find the optimal θ.

  2. Approximating the posterior probability $p_θ(𝐳|𝐱)$ given an observed value 𝐱 and a choice of parameters θ is useful for encoding data.

    The latent variable 𝐳 can be regarded as a latent representation of a datapoint 𝐱.

  3. Approximating the marginal likelihood $p_θ(𝐱)$ enables inference tasks where a prior over 𝐱 is required, e.g. image denoising and inpainting.

The above three goals can be tackled by a recognition model $q_φ(𝐳|𝐱)$, which is an approximation to the intractable posterior probability $p_θ(𝐳|𝐱)$

The recognition model $q_φ(𝐳|𝐱)$ is a probabilistic encoder, which produces a distribution over all possible values of 𝐳 given a datapoint 𝐱.

(Each datapoint 𝐱 corresponds to a distribution $q_φ(𝐳|𝐱)$ with some parameters (e.g., μ,σ).)

[Figure: the probabilistic encoder maps a datapoint 𝐱 to a distribution over codes, p(𝐳|𝐱); the probabilistic decoder maps a code 𝐳 to a distribution over datapoints, p(𝐱|𝐳).]

A 𝐳 yields a conditional distribution p(𝐱|𝐳), from which a datapoint 𝐱 is generated, so the generative model $p_θ(𝐱|𝐳)$ is called the probabilistic decoder.

(2023-07-09)

  • So a training step is: sample a 𝐳 from the distribution p(𝐳|𝐱); then, with this 𝐳, find the parameter θ that maximizes the probability p(𝐱|𝐳) based on MLE. That’s why the loss function has a cross-entropy term that measures the likelihood of the output 𝐱’, i.e., log p(𝐱|𝐳).

  • According to the KL divergence $D_{KL}(q(𝐳) || p(𝐳|𝐱)) = -∫ q(𝐳)\, log \frac{p(𝐳|𝐱)}{q(𝐳)}\, d𝐳$, the approximate posterior q(𝐳) is supposed to equal the true posterior p(𝐳|𝐱), which, however, is intractable. So the authors use a network and reparameterization to estimate the parameters (mean, variance) of the posterior.

  • (I think) 𝐳 can be sampled from the posterior once, or sampled multiple times with the results averaged.

2.2 The Variational bound

The log marginal likelihood of the dataset with N datapoints is a sum over the log marginal likelihoods of the individual datapoints:

$log\ p_θ(𝐱⁽¹⁾, …, 𝐱⁽ᴺ⁾ ) = ∑ᵢ₌₁ᴺ log\ p_θ(𝐱⁽ⁱ⁾)$

  • Rewrite it by introducing the posterior approximation $q_φ(𝐳|𝐱⁽ⁱ⁾)$ to form a KL divergence:

    $$ \begin{aligned} &log\ p_θ(𝐱⁽ⁱ⁾) \\ &= log\ (\frac{p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)}{p_θ(𝐳|𝐱⁽ⁱ⁾)} ) \\ &= log\ (p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)) - log\ p_θ(𝐳|𝐱⁽ⁱ⁾) \\ &= log\ (p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)) - log\ p_θ(𝐳|𝐱⁽ⁱ⁾) \\ &\quad + log\ q_φ(𝐳|𝐱⁽ⁱ⁾) - log\ q_φ(𝐳|𝐱⁽ⁱ⁾) \\ &= log\ \frac{p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} - log\ \frac{p_θ(𝐳|𝐱⁽ⁱ⁾)}{q_φ(𝐳|𝐱⁽ⁱ⁾)}\\ \end{aligned} $$

  • Compute expectations w.r.t. the approximate posterior $q_φ(𝐳|𝐱)$ on both sides:

    $$ ∫q_φ(𝐳|𝐱⁽ⁱ⁾) log\ p_θ(𝐱⁽ⁱ⁾) d𝐳 = \\ ∫q_φ(𝐳|𝐱⁽ⁱ⁾) \left[log\ \frac{p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} - log\ \frac{p_θ(𝐳|𝐱⁽ⁱ⁾)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} \right] d𝐳 $$

  • The left-hand side remains the marginal likelihood $log\ p_θ(𝐱⁽ⁱ⁾)$ (i.e. $𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [ log\ p_θ(𝐱⁽ⁱ⁾) ]$), because $p_θ(𝐱⁽ⁱ⁾)$ does not depend on 𝐳 and $q_φ(𝐳|𝐱⁽ⁱ⁾)$ integrates to 1.

    While right-hand side is the lower bound plus KL divergence:

    $$ \begin{aligned} & log\ p_θ(𝐱⁽ⁱ⁾) = \\ & ∫q_φ(𝐳|𝐱⁽ⁱ⁾) log\ \frac{p_θ(𝐱⁽ⁱ⁾|𝐳) p_θ(𝐳)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} d𝐳 \\ &\quad - ∫q_φ(𝐳|𝐱⁽ⁱ⁾) log\ \frac{p_θ(𝐳|𝐱⁽ⁱ⁾)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} d𝐳 \\ & = ℒ(θ,φ; 𝐱⁽ⁱ⁾) + D_{KL} ( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳|𝐱⁽ⁱ⁾) ) \\ & \\ & = 𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [log\ p_θ(𝐱⁽ⁱ⁾,𝐳) - log\ q_φ(𝐳|𝐱⁽ⁱ⁾)] \\ &\quad + D_{KL} ( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳|𝐱⁽ⁱ⁾) ) \quad (2) \end{aligned} $$

Alternatively, the lower bound can also be written as eq. (3):

𝓛$(θ,φ; 𝐱⁽ⁱ⁾) = -D_{KL} ( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳) ) + 𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [log\ p_θ(𝐱⁽ⁱ⁾|𝐳)]$

whose derivation starts from the conditional probability $p_θ(𝐱⁽ⁱ⁾|𝐳)$:

  • The likelihood of the i-th datapoint: $$ \begin{aligned} & log(p_θ(𝐱⁽ⁱ⁾|𝐳)) \\ &= log( \frac{p_θ(𝐳|𝐱⁽ⁱ⁾)p_θ(𝐱⁽ⁱ⁾)}{p_θ(𝐳)} )\\ &= log( \frac{ p_θ(𝐳|𝐱⁽ⁱ⁾)p_θ(𝐱⁽ⁱ⁾) \cdot q_φ(𝐳|𝐱⁽ⁱ⁾) }{ p_θ(𝐳) \cdot q_φ(𝐳|𝐱⁽ⁱ⁾) } ) \\ &= log( \frac{ p_θ(𝐳|𝐱⁽ⁱ⁾)p_θ(𝐱⁽ⁱ⁾)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} ) - log( \frac{p_θ(𝐳)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} ) \end{aligned} $$

  • Compute the expectation w.r.t. approximate posterior $q_φ(𝐳|𝐱⁽ⁱ⁾)$:

    $$ \begin{aligned} & ∫q_φ(𝐳|𝐱⁽ⁱ⁾)\ log(p_θ(𝐱⁽ⁱ⁾|𝐳)) d𝐳 = \\ & ∫q_φ(𝐳|𝐱⁽ⁱ⁾) log( \frac{ p_θ(𝐳|𝐱⁽ⁱ⁾)p_θ(𝐱⁽ⁱ⁾)}{q_φ(𝐳|𝐱⁽ⁱ⁾)})d𝐳\\ & \quad - ∫q_φ(𝐳|𝐱⁽ⁱ⁾) log( \frac{p_θ(𝐳)}{q_φ(𝐳|𝐱⁽ⁱ⁾)} ) d𝐳 \\ & \\ & 𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [log\ p_θ(𝐱⁽ⁱ⁾|𝐳)] = \\ & \quad ℒ(θ,φ; 𝐱⁽ⁱ⁾) + D_{KL} ( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳) ) \end{aligned} $$

However, estimating the gradient of 𝓛 (an expectation) with the naive Monte Carlo estimator exhibits very high variance.
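For reference, the naive estimator meant here is the standard score-function (REINFORCE-style) form; its per-sample terms multiply $f(𝐳)$ by the score, which is what makes them fluctuate so strongly:

$$ ∇_φ\, 𝔼_{q_φ(𝐳)} [f(𝐳)] = 𝔼_{q_φ(𝐳)} [f(𝐳)\, ∇_φ\, log\ q_φ(𝐳)] ≃ \frac{1}{L} ∑ₗ₌₁ᴸ f(𝐳⁽ˡ⁾)\, ∇_φ\, log\ q_φ(𝐳⁽ˡ⁾) $$

where 𝐳⁽ˡ⁾ ~ $q_φ(𝐳)$.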


2.3 SGVB estimator and AEVB algo

A practical estimator of the lower bound 𝓛 of the likelihood and its derivatives w.r.t. the parameters.

(Under some mild conditions on the chosen approximate posterior distribution $q_φ(𝐳|𝐱)$:)

Reparameterize the random variable $\~𝐳 \sim q_φ(𝐳|𝐱)$ (or, in the unconditioned case, $q_φ(𝐳)$) as $\~𝐳 = g_φ(𝛆,𝐱)$ with 𝛆 ~ p(𝛆), where $g_φ(𝛆,𝐱)$ is a deterministic differentiable transformation of an (auxiliary) noise variable 𝛆.

Since $\~𝐳$ is a deterministic function of 𝛆, an expectation w.r.t. $q_φ(𝐳|𝐱)$ can be rewritten as an expectation w.r.t. p(𝛆).

  1. Monte Carlo estimation

    Therefore, approximating an expectation of some function $f(𝐳)$ w.r.t. the posterior approximation $q_φ(𝐳|𝐱⁽ⁱ⁾)$ via Monte Carlo (i.e., averaging L sampled values) becomes:

    $$ 𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [f(𝐳)] = 𝔼_{p(ε)} [f( g_φ(ε, 𝐱⁽ⁱ⁾) )] \\ ≃ \frac{1}{L} ∑ₗ₌₁ᴸ f( g_φ(ε⁽ˡ⁾, 𝐱⁽ⁱ⁾) ), $$ where 𝛆⁽ˡ⁾~ p(𝛆).

  2. Approximate lower bound with eq. (2)

    Version A of the SGVB estimator comes from eq. (2): by eq. (2), the lower bound 𝓛(𝛉,𝛗;𝐱⁽ⁱ⁾) is exactly the expectation $𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)}[log\ p_θ(𝐱⁽ⁱ⁾,𝐳) - log\ q_φ(𝐳|𝐱⁽ⁱ⁾)]$, so the estimator targets this expectation directly, $\~𝓛ᴬ(𝛉,𝛗;𝐱⁽ⁱ⁾) ≃ 𝓛(𝛉,𝛗;𝐱⁽ⁱ⁾)$.

    And that expectation can be approximated via Monte Carlo (sampling), so the lower bound is approximated as:

    $$ \~\mathcal L^A (θ,φ,𝐱⁽ⁱ⁾) = \\ \frac{1}{L} ∑ₗ₌₁ᴸ \left[ log\ p_θ(𝐱⁽ⁱ⁾, 𝐳⁽ⁱ’ˡ⁾) - log\ q_φ(𝐳⁽ⁱ’ˡ⁾| 𝐱⁽ⁱ⁾) \right] \\ $$

    where $𝐳⁽ⁱ’ˡ⁾ = g_φ(ε⁽ˡ⁾, 𝐱⁽ⁱ⁾)$ and ε⁽ˡ⁾ ~ p(𝛆)

  3. Approximate lower bound with eq. (3)

    Since the KL divergence $D_{KL}( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳) )$ of eq. (3) can often be integrated analytically, only the expectation (the “expected reconstruction error” $𝔼_{q_φ(𝐳|𝐱⁽ⁱ⁾)} [log\ p_θ(𝐱⁽ⁱ⁾|𝐳)]$) needs Monte Carlo approximation. The resulting second version of the lower bound approximation, $\~𝓛ᴮ(𝛉,𝛗;𝐱⁽ⁱ⁾) ≃ 𝓛(𝛉,𝛗;𝐱⁽ⁱ⁾)$, typically has lower variance than version A (a sketch of both estimators follows this list).

    $$ \tilde{\mathcal L^B} = -D_{KL}( q_φ(𝐳|𝐱⁽ⁱ⁾) || p_θ(𝐳) ) \\ \qquad + \frac{1}{L} ∑ₗ₌₁ᴸ log\ p_θ(𝐱⁽ⁱ⁾ | 𝐳⁽ⁱ’ˡ⁾) $$

    where $𝐳⁽ⁱ’ˡ⁾ = g_φ(ε⁽ˡ⁾, 𝐱⁽ⁱ⁾)$ and ε⁽ˡ⁾ ~ p(𝛆).

    The KL divergence term can be interpreted as a regularizer, and the optimization objective is to maximize the “expected negative reconstruction error”, i.e., to recover the original 𝐱 from the code 𝐳.

  4. Train with minibatches

    Given a dataset 𝐗 with N datapoints, M datapoints are drawn as a minibatch at each step; the marginal-likelihood lower bound of the full dataset is then estimated batch-by-batch as:

    $$ \mathcal L(\pmb{θ,φ}; 𝐗) \simeq \tilde{\mathcal L}^M (\pmb{θ,φ}; 𝐗ᴹ) = \frac{N}{M} ∑ᵢ₌₁ᴹ \tilde{\mathcal L} (\pmb{θ,φ}; 𝐱⁽ⁱ⁾) $$
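A minimal sketch of the two estimator versions under a diagonal-Gaussian $q_φ$; the function names and tensor shapes are my assumptions, not the paper’s notation, while the analytic KL expression is the one from the paper’s Appendix B:

    import torch

    # Version A: average log p(x,z) - log q(z|x) over L reparameterized samples.
    def elbo_a(log_pxz, log_qzx):
        # log_pxz, log_qzx: (M, L) tensors evaluated at z = g_phi(eps, x)
        return (log_pxz - log_qzx).mean(dim=1)

    # Version B: analytic -D_KL(q || N(0, I)) plus a Monte Carlo estimate of
    # the expected reconstruction log-likelihood (typically lower variance).
    def elbo_b(mu, logvar, log_px_given_z):
        # mu, logvar: (M, J) parameters of q; log_px_given_z: (M, L)
        neg_kld = 0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
        return neg_kld + log_px_given_z.mean(dim=1)

    # Full-dataset bound from a minibatch of M out of N datapoints:
    # L(theta, phi; X) ≈ (N / M) * elbo_b(mu, logvar, log_px_given_z).sum()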

2.4 Reparameterization trick

Previously, 𝐳 was sampled directly and then fed into the model $p_θ(𝐱|𝐳)$; the parameter φ of the distribution of 𝐳 could not be optimized via gradient descent, since the sampling operation is not differentiable.

[Diagram: 𝐳 is sampled (M.C.) from the prior p(𝐳) / from q(𝐳), then passed through θ to produce 𝐱; gradients cannot flow back through the sampling node.]

Instead of sampling 𝐳 directly, 𝐳 is derived from a sampled 𝛆 via the deterministic differentiable transformation 𝐳 = gᵩ(𝛆, 𝐱).

[Diagram: 𝛆 is sampled (M.C.) from the distribution p(𝛆); then 𝐳 = μ + σ𝛆 is computed and passed through θ to produce 𝐱.]

(The non-differentiable sampling operation (M.C.) is moved ahead to a leaf node of the computational graph.)

Then the parameters (e.g. μ,σ²) of 𝐳’s distribution can be learned by a network with parameters φ.

[Diagram: 𝐱 → encoder φ → (μ, log σ²); 𝛆 is sampled from the distribution p(𝛆); 𝐳 = μ + σ𝛆; 𝐳 → decoder θ → 𝐱, trained with the reconstruction loss.]

Such that the Monte Carlo estimate of the expectation is differentiable w.r.t. φ.

The reasoning is as follows:

  • For infinitesimals, there is $$q_φ(𝐳|𝐱)∏ᵢdzᵢ = p(\pmb ε)∏ᵢdεᵢ$$, where zᵢ is one dimension of 𝐳 and d𝐳 = ∏ᵢ dzᵢ

  • Therefore, $∫ q_φ(𝐳|𝐱) f(𝐳) d𝐳$ = ∫p(𝛆) f(𝐳) d𝛆 = ∫ p(𝛆) $f(g_φ$(𝛆,𝐱)) d𝛆

  • Use Monte Carlo sampling approximation: $∫ q_φ(𝐳|𝐱) f(𝐳) d𝐳$ ≃ 1/L ∑ₗ₌₁ᴸ f( gᵩ(𝛆⁽ˡ⁾, 𝐱)), where 𝛆⁽ˡ⁾ ~ p(𝛆)
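A minimal check that this Monte Carlo estimate is indeed differentiable w.r.t. φ; here $f(𝐳) = ‖𝐳‖²$ is an arbitrary stand-in of my choosing:

    import torch

    mu = torch.zeros(3, requires_grad=True)       # parameters of q(z|x), i.e. phi
    logvar = torch.zeros(3, requires_grad=True)

    eps = torch.randn(100, 3)                     # L = 100 noise samples, eps ~ N(0, I)
    z = mu + torch.exp(0.5 * logvar) * eps        # z = g_phi(eps, x), differentiable
    f = (z ** 2).sum(dim=1).mean()                # (1/L) * sum_l f(z^(l))
    f.backward()
    print(mu.grad, logvar.grad)                   # gradients reach the distribution parameters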

The transformation $g_φ$ from 𝛆 to 𝐳 can be learned by back-propagation and gradient descent from the reconstruction loss, since the transformation is differentiable.

So the parameters (e.g. μ,σ) of the approximate posterior distribution $q_φ(𝐳|𝐱)$ of 𝐳 can be optimized through φ, and the parameter θ of the generative model is trained jointly.

This is just a trick; the essence of the algorithm is still coordinate ascent:

  1. Sample a 𝐳 by sampling an 𝛆:

    def reparameterize(mu, logvar): # logvar is log𝛔²
      std = torch.exp(0.5 * logvar) # 𝛔 = (e^{log𝛔²})^½
      eps = torch.randn_like(std)   # 𝛆 ~ N(0, I)
      return mu + eps * std         # z = μ + 𝛔 ⊙ 𝛆
    
  2. Use 𝐳 to generate 𝐱 with model $p_θ(𝐱|𝐳)$,

    def decode(z):
      # a sketch: real code would create the layers once in __init__
      hidden = torch.relu(nn.Linear(latent_dim, hidden_dim)(z))
      return torch.sigmoid(nn.Linear(hidden_dim, input_dim)(hidden))
    
  3. then use 𝐱 to produce 𝐳 with model $q_φ(𝐳|𝐱)$

    def encode(input):
      result = torch.relu(nn.Linear(input_dim, hidden_dim)(input))
      mu = nn.Linear(hidden_dim, latent_dim)(result)      # μ of q(z|x)
      logvar = nn.Linear(hidden_dim, latent_dim)(result)  # log𝛔² of q(z|x)
      return [mu, logvar]
    

With the reparameterization trick, both models can be optimized jointly by gradient descent.
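For orientation, a hypothetical forward pass chaining the three steps above; the loss that consumes these outputs is derived in the next section:

    def training_step(x):
        mu, logvar = encode(x)           # q_phi(z|x): predict the code's distribution
        z = reparameterize(mu, logvar)   # sample z differentiably via eps
        recon_x = decode(z)              # p_theta(x|z): reconstruct x from z
        return recon_x, mu, logvar       # inputs to the loss function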

(2024-04-16)

  • The reparameterization trick is a sampling method. It is similar to inverse transform sampling: samples from a simple known distribution (uniform there, standard normal here) are transformed into samples from the target distribution (a sketch follows this list).

  • Reparameterization trick makes the sampling from an unknown distribution differentiable, and enables the parameters of the distribution to be optimized.
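To make the analogy concrete, a small sketch; torch.special.ndtri (the inverse standard-normal CDF) plays the role of the inverse transform here:

    import torch

    mu = torch.tensor(2.0, requires_grad=True)
    sigma = torch.tensor(0.5, requires_grad=True)

    # Inverse transform sampling: push uniform samples through the inverse CDF.
    u = torch.rand(10000)
    z_its = mu + sigma * torch.special.ndtri(u)

    # Reparameterization: push standard-normal samples through g_phi.
    eps = torch.randn(10000)
    z_rep = mu + sigma * eps   # same N(mu, sigma^2); gradients flow to mu, sigma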


Loss func

Loss function contains two parts: KL divergence and reconstruction error.

  1. The KL divergence can be computed in closed form, given the prior and the assumed posterior of 𝐳 (see the loss sketch at the end of this section).

    For example, let the prior $p_θ(𝐳) = N(𝐳; 𝟎, 𝐈)$ and assume the posterior qᵩ(𝐳|𝐱) = N(𝐳; 𝛍, 𝛔²𝐈); plugging both Gaussians into the KL integral yields the closed form $D_{KL} = -\frac{1}{2} ∑ⱼ₌₁ᴶ (1 + log\ σⱼ² - μⱼ² - σⱼ²)$ (paper, Appendix B).

  2. The reconstruction error is the log likelihood of the input datapoint log p(𝐱|𝐳).

    • For a datapoint obeying a multivariate Bernoulli distribution (binary values), the log likelihood of 𝐱 is

      $$ log p(𝐱|𝐳) = log ∏ᵢ₌₁ᴰ yᵢˣⁱ (1-yᵢ)¹⁻ˣⁱ \\ \ = ∑ᵢ₌₁ᴰ [ xᵢlog yᵢ + (1-xᵢ)log (1-yᵢ) ] $$

      where yᵢ is the decoder output and must behave like a probability (hence a sigmoid activation). So in this case, this loss term is a binary cross entropy.

    • For a datapoint following multivariate Gaussian distribution N(𝐱; 𝛍, 𝛔²𝐈), refer to Su, Jianlin

      $$ \begin{aligned} &p(𝐱|𝐳) = \frac{1}{∏ᵢ₌₁ᴰ \sqrt{2πσᵢ²(𝐳)}} \rm exp(-\frac{1}{2} ‖\frac{𝐱-\pmb μ(𝐳)}{\pmb σ(𝐳)}‖²) \\ \ \\ &log\ p(𝐱|𝐳) \\ &= log \frac{1}{∏ᵢ₌₁ᴰ \sqrt{2πσᵢ²(𝐳)}} \ -\frac{1}{2} ‖\frac{𝐱-\pmb μ(𝐳)}{\pmb σ(𝐳)}‖² \\ &= -∑ᵢ₌₁ᴰ log \sqrt{2πσᵢ²(𝐳)} -\frac{1}{2} ‖\frac{𝐱-\pmb μ(𝐳)}{\pmb σ(𝐳)}‖² \\ &= -∑ᵢ₌₁ᴰ [\frac{1}{2} (log2π + logσᵢ²(𝐳))] - \frac{1}{2} ‖\frac{𝐱-\pmb μ(𝐳)}{\pmb σ(𝐳)}‖² \\ &= -\frac{D}{2} log2π -\frac{1}{2}∑ᵢ₌₁ᴰlogσᵢ²(𝐳) -\frac{1}{2} ‖\frac{𝐱-\pmb μ(𝐳)}{\pmb σ(𝐳)}‖² \\ \end{aligned} $$

      Normally, the variance 𝛔² is fixed, so this loss term depends only on the mean 𝛍(𝐳):

      -log p(𝐱|𝐳) ~ $\frac{1}{2\pmb σ²}\| 𝐱-\pmb μ(𝐳) \|²$

      Therefore, this loss term is an MSE (up to a constant scale).
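Putting the two terms together, a minimal sketch of the loss for the Bernoulli case (for the Gaussian case, swap the BCE for the scaled MSE above); the function name is mine:

    import torch
    import torch.nn.functional as F

    def vae_loss(recon_x, x, mu, logvar):
        # negative reconstruction log-likelihood: binary cross entropy,
        # valid when recon_x passed through a sigmoid and x lies in [0, 1]
        recon = F.binary_cross_entropy(recon_x, x, reduction='sum')
        # closed-form D_KL( N(mu, sigma^2 I) || N(0, I) )
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon + kld   # minimizing this maximizes the ELBO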