Mean Field Variational Inference


Preliminaries

Some key (but non-exhaustive) concepts to be familiar with to help with derivations and general intuition when dealing with topics in structured representation learning!

Statistical Inference as Optimization

Variational auto-encoders (VAEs) on a high level are composed of the following:
  1. Encoder ( \( \phi \) ): learn the \(\mu, \sigma \) for \( z \sim \mathcal{N}(\mu, \sigma^2) \) so that we can sample from \(q_{\phi}(z|x)\)
  2. Decoder ( \( \theta \) ): learn to reconstruct observations from latent samples, \( \tilde x \sim p_{\theta}(x | z) \)
Variational inference (VI) is the algorithm used to learn the parameters \( \theta \) and \( \phi \), approximating the mapping from the inputs / observations \( x \) to the process that generates them, \( p(x) \) (the true distribution).

To generate samples \( \tilde x^{(i)} \sim p_{\theta}(x) \) that look like \( x^{(i)} \):
  1. Sample from the prior: \( z^{(i)} \sim p_{\theta}(z) \)
  2. Sample from the conditional likelihood: \( x^{(i)} \sim p_{\theta}(x|z) \)
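
As a minimal sketch of this ancestral sampling procedure (the `decoder_apply` function and `theta` parameters here are hypothetical placeholders, not the implementation from my linked post), generation could look like:

```python
import jax
import jax.numpy as jnp

def generate(rng, decoder_apply, theta, latent_dim, num_samples):
    """Ancestral sampling: z ~ p(z), then x ~ p_theta(x|z)."""
    rng_z, rng_x = jax.random.split(rng)
    # 1. Sample latents from a standard normal prior p(z) = N(0, I).
    z = jax.random.normal(rng_z, (num_samples, latent_dim))
    # 2. Decode into Bernoulli pixel probabilities and sample observations.
    logits = decoder_apply(theta, z)
    x = jax.random.bernoulli(rng_x, jax.nn.sigmoid(logits))
    return x
```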

Re-parameterization trick 

In VAEs, the encoder determines the distribution of the latent variable \( z \). In order to differentiate through samples generated by the encoder network, we re-parameterize \( z \) as a deterministic transformation of a sample from a fixed distribution, e.g. a standard Gaussian \( \epsilon \sim \mathcal{N}(0, 1) \): \[ z = \mu + \sigma \cdot \epsilon \text{ such that } z \sim \mathcal{N}(\mu, \sigma^2) \] so that gradients can flow through the encoder outputs \( \mu \) and \( \sigma \).
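
A minimal JAX sketch of this trick (assuming the encoder outputs \( \mu \) and \( \log \sigma^2 \), a common but not universal convention):

```python
import jax
import jax.numpy as jnp

def reparameterize(rng, mu, logvar):
    """Sample z ~ N(mu, sigma^2) as a differentiable function of (mu, logvar)."""
    sigma = jnp.exp(0.5 * logvar)
    # eps carries all the randomness; gradients flow through mu and sigma.
    eps = jax.random.normal(rng, mu.shape)
    return mu + sigma * eps
```

Gradients w.r.t. \( \mu \) and \( \log \sigma^2 \) flow through this deterministic function, while the randomness is isolated in \( \epsilon \).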

Bayes' Theorem

Given the prior (our assumptions about the latent variables) \( p_{\theta}(z) \), posterior (discriminative model) \(p_{\theta}(z|x) \), likelihood (generative model) \(p_{\theta}(x|z) \), and joint \( p_{\theta}(z, x) \) distributions:
$$ p_{\theta}(x|z) p_{\theta}(z) = p_{\theta}(z, x) = p_{\theta}(z|x) p_{\theta}(x) $$

Inference on \( z \)

We want to learn a posterior distribution over the latent (unobservable) variables \( z \). Exact Bayesian inference (the quantity that sampling-based approaches like MCMC target) gives this posterior as: $$ p_{\theta}(z|x) = \frac{p_{\theta}(z)p_{\theta}(x|z)}{\int p_{\theta}(z)p_{\theta}(x|z)dz} $$
Notice that the marginal likelihood in the denominator becomes intractable when \( z \) is high-dimensional.
In VI, we instead learn the parameters \( \phi \) of an approximate posterior \( q_{\phi}(z) \), restricted to a tractable family of distributions that factorizes (aka 'mean-field'). We then do stochastic optimization, maximizing an objective that is a lower bound on the log marginal likelihood: the evidence lower bound (ELBO).
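
Concretely, the mean-field assumption is that the approximate posterior factorizes across the latent dimensions, each with its own variational parameters:
$$ q_{\phi}(z) = \prod_{i=1}^{d} q_{\phi_i}(z_i) $$
In the Gaussian VAE case this corresponds to a diagonal covariance: each \( z_i \) gets its own \( \mu_i(x), \sigma_i^2(x) \).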

Check out my JAX implementation of a VAE trained on a toy MNIST classification task here.

\( \log{p(.)} \)

Deep learning algorithms often operate on log probabilities for numerical stability (i.e. avoiding underflow and NaNs). The intuition is that when multiplying probabilities \( p_i \in [0, 1] \), the product generally (with the exception of \( p_i = 1 \)) becomes smaller and smaller. This becomes a problem when there is limited floating-point precision available to represent such small values. It is therefore easier to work with the larger-magnitude log values, since \( \log(p_1 p_2) = \log(p_1) + \log(p_2) \) where \( \log p_i \leq 0 \).
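
A quick toy illustration (not from the post): multiplying many small probabilities underflows in float32, while summing their logs stays perfectly representable:

```python
import jax.numpy as jnp

p = jnp.full(1000, 0.01, dtype=jnp.float32)   # 1000 independent events, each p_i = 0.01
print(jnp.prod(p))          # 0.0 -- underflows; the true value 1e-2000 is unrepresentable
print(jnp.sum(jnp.log(p)))  # ~ -4605.17, i.e. the log of the product, computed stably
```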

Jensen's Inequality

Let \( f(x) \) be a convex function. Then, $$ \mathbb{E}[f(x)] \geq f(\mathbb{E}[x]) $$ For concave \( f \) (e.g. \( \log \)), the inequality flips: \( \mathbb{E}[f(x)] \leq f(\mathbb{E}[x]) \), which is the form used in the ELBO derivation below.
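
A familiar special case: taking the convex \( f(x) = x^2 \) recovers the fact that variance is non-negative:
$$ \mathbb{E}[x^2] \geq (\mathbb{E}[x])^2 \iff \operatorname{Var}(x) = \mathbb{E}[x^2] - (\mathbb{E}[x])^2 \geq 0 $$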

KL Divergence

Let \(p, q \) be the true and learned distributions respectively. The KL divergence measures the expected number of extra bits (when using \( \log_2 \); nats for the natural log) needed to encode samples from one distribution with a code optimized for the other, and we generally minimize it during optimization. There are two forms, the latter of which is utilized in VI:
  • Forward KL ("mode covering"): \( KL(p(x) || q(x)) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)}dx \). Notice that since \( q(x) \) is in the denominator, the penalty blows up wherever \( q(x) \) is small but \( p(x) \) is not. Thus \(q(x) \) "covers" as much of the support where \( p(x) \neq 0 \) as possible, spreading density even over low-probability regions of \( p(x) \).
  • Reverse KL ("mode seeking"): \( KL(q(x) || p(x)) = \int_{-\infty}^{\infty} q(x) \log \frac{q(x)}{p(x)}dx \). Since \( p(x) \) is in the denominator, it cannot be small wherever \(q(x) \) is non-zero. Thus the algorithm ends up locking onto "modes", i.e. high-density areas of \( p(x) \), since regions where \(q(x) = 0\) contribute no penalty at all.
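
In the VAE setting with a diagonal Gaussian approximate posterior and a standard normal prior, the reverse-KL term that appears in the ELBO has a well-known closed form (per latent dimension, summed over dimensions in practice):
$$ KL\Big[ \mathcal{N}(\mu, \sigma^2) \,||\, \mathcal{N}(0, 1) \Big] = \frac{1}{2}\left( \mu^2 + \sigma^2 - 1 - \log \sigma^2 \right) $$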

Derivation of ELBO

Training Objective

The VI objective is to minimize the KL-divergence between our approximate and true posterior distribution:
$$\begin{align}KL\Big[ q_{\phi}(z|x) || p_{\theta}(z|x) \Big] &= \int q_{\phi}(z|x) \log{\frac{q_\phi(z|x)}{p_{\theta}(z|x)}}dz \\ &=  \int q_{\phi}(z|x) \log{\frac{p_{\theta}(x) q_\phi(z|x)}{p_{\theta}(z, x)}}dz \\ \end{align}$$

We want to surface the likelihood term \( p_{\theta}(x|z) \) so that we can do standard maximum likelihood estimation (MLE) and also the KL term between the approximate posterior \( q_{\phi}(z|x) \) and prior \( p_{\theta}(z) \):
$$ \begin{align}&= \int q_{\phi}(z|x) \log p_{\theta}(x)dz + \int q_{\phi}(z|x) \log \frac{q_{\phi}(z|x)}{p_{\theta}(z,x)}dz \\ &= \int q_{\phi}(z|x) \log p_{\theta}(x)dz + \int q_{\phi}(z|x) \log \frac{q_{\phi}(z|x)}{p_{\theta}(x|z) p_{\theta}(z)}dz \end{align}$$

Notice that since \( \log p_{\theta}(x) \) does not depend on \( z \), it can be pulled out of the first integral, and \( \int q_{\phi}(z|x)\,dz = 1 \).
$$ \begin{align} &= \log p_{\theta}(x) + \int q_{\phi}(z|x) \left[ \log \frac{q_{\phi}(z|x)}{p_{\theta}(x|z)} - \log p_{\theta}(z) \right] dz \\ &= \log p_{\theta}(x) + \int q_{\phi}(z|x) \log \frac{q_{\phi}(z|x)}{p_{\theta}(z)} dz -  \int q_{\phi}(z|x) \log p_{\theta}(x|z) dz \\ &= \log p_{\theta}(x) + \mathbb{E}_{z \sim q_{\phi}(z|x)}\left[ \log \frac{q_{\phi}(z|x)}{p_{\theta}(z)} \right]  - \mathbb{E}_{z \sim q_{\phi}(z|x)} \left[ \log p_{\theta}(x|z) \right] \end{align}$$

The final two terms are the negative ELBO, which is exactly our optimizer's training objective: minimizing negative log-likelihood + KL.
$$ \begin{align} KL\Big[ q_{\phi}(z|x) || p_{\theta}(z|x) \Big] &= \log p_{\theta}(x) + KL\left[ q_{\phi}(z|x) || p_{\theta}(z)\right] - \text{log likelihood of observations} \\ &= \log p_{\theta}(x) -ELBO \end{align}$$
The marginal log likelihood of the real data \(\log p_{\theta}(x) \) is fixed w.r.t. the variational parameters \(\phi \), so minimizing KL(approx. posterior || true posterior) == maximizing the \( ELBO \) (or minimizing \(-ELBO \) ).
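
A minimal sketch of this objective for a Bernoulli decoder and diagonal Gaussian posterior (the `encoder_apply` / `decoder_apply` functions are hypothetical placeholders, not the implementation linked above):

```python
import jax
import jax.numpy as jnp

def negative_elbo(rng, phi, theta, x, encoder_apply, decoder_apply):
    """-ELBO = reconstruction NLL + KL(q_phi(z|x) || p(z)), estimated with one z sample."""
    mu, logvar = encoder_apply(phi, x)
    # Reparameterized sample z ~ q_phi(z|x).
    eps = jax.random.normal(rng, mu.shape)
    z = mu + jnp.exp(0.5 * logvar) * eps
    # Reconstruction term: -E_q[log p_theta(x|z)], numerically stable Bernoulli NLL from logits.
    logits = decoder_apply(theta, z)
    recon_nll = jnp.sum(
        jnp.maximum(logits, 0) - logits * x + jnp.log1p(jnp.exp(-jnp.abs(logits)))
    )
    # KL(q_phi(z|x) || N(0, I)) in closed form (see the Gaussian KL expression above).
    kl = 0.5 * jnp.sum(mu**2 + jnp.exp(logvar) - 1.0 - logvar)
    return recon_nll + kl
```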

Why is the ELBO a lower bound on the evidence?

Notice that there are TWO KL terms, one inside the ELBO (against the prior) and another measuring the gap between the approximate and true posterior, and \( KL \geq 0 \):
$$ \begin{align}  \log p_{\theta}(x) &=  \text{log likelihood of observations} - \text{KL(approx. post || prior)} + KL\Big[ q_{\phi}(z|x) || p_{\theta}(z|x) \Big] \\ &= ELBO + KL\text{(approx. posterior || true posterior)} \\ &\geq ELBO  \end{align}$$

Alternatively, by Jensen's inequality (applied to the concave \( \log \)), the ELBO is a lower bound on the marginal log likelihood:
$$ \begin{align}  \log p_{\theta}(x) &= \log \int p_{\theta}(x, z)\,dz = \log \mathbb{E}_{z \sim q_{\phi}(z|x)} \left[ \frac{p_{\theta}(x, z)}{q_{\phi}(z|x)} \right] \\ &\geq \mathbb{E}_{z \sim q_{\phi}(z|x)} \left[ \log \frac{p_{\theta}(x, z)}{q_{\phi}(z|x)} \right] = ELBO  \end{align}$$

This makes it a sensible quantity to optimize: if the KL divergence between the approximate and true posterior is zero, then our bound on \( \log p_{\theta}(x) \) is tight!

Additional Resources

If you are interested in looking at more advanced topics, I really enjoyed the following slides/papers/blogs:
