Preliminaries
Some key (but non-exhaustive) concepts to be familiar with; they help with the derivations and general intuition when dealing with topics in structured representation learning!
Statistical Inference as Optimization
Variational auto-encoders (VAEs), at a high level, are composed of the following (a minimal code sketch follows the list):
- Encoder $q_\phi(z \mid x)$: learn the approximate posterior over the latent variable $z$ so that we can sample from it.
- Decoder $p_\theta(x \mid z)$: learn a representation of the observations $x$ given a latent sample $z$.
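A minimal sketch of these two components as plain JAX functions; the shapes, the single hidden layer, and the parameter initialization are illustrative assumptions, not details from this post:

```python
import jax
import jax.numpy as jnp

obs_dim, hidden_dim, latent_dim = 4, 8, 2

def init_params(key):
    # toy random parameters, just so the sketch runs end-to-end
    k1, k2, k3, k4 = jax.random.split(key, 4)
    phi = {"w": 0.1 * jax.random.normal(k1, (obs_dim, hidden_dim)),
           "w_mu": 0.1 * jax.random.normal(k2, (hidden_dim, latent_dim)),
           "w_logsigma": 0.1 * jax.random.normal(k3, (hidden_dim, latent_dim))}
    theta = {"w": 0.1 * jax.random.normal(k4, (latent_dim, obs_dim))}
    return phi, theta

def encoder(phi, x):
    # q_phi(z | x): produce the mean and log-std of a Gaussian over z
    h = jnp.tanh(x @ phi["w"])
    return h @ phi["w_mu"], h @ phi["w_logsigma"]

def decoder(theta, z):
    # p_theta(x | z): map a latent sample back to (the mean of) the observation
    return jax.nn.sigmoid(z @ theta["w"])

phi, theta = init_params(jax.random.PRNGKey(0))
mu, log_sigma = encoder(phi, jnp.ones(obs_dim))
x_hat = decoder(theta, mu)   # decode the posterior mean as a quick check
```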
Variational inference (VI) is an algorithm for learning the parameters $\phi$ and $\theta$ that approximate the mapping from the inputs / observations $x$ to the process generating such inputs (the true distribution).
To generate samples that look like $x$ (sketched in code below):
- Sample from the prior: $z \sim p(z)$
- Sample from the conditional likelihood: $x \sim p_\theta(x \mid z)$
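A tiny sketch of this two-step (ancestral) sampling procedure; the `tanh` decoder is just a stand-in for a trained $p_\theta(x \mid z)$:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (2,))   # step 1: z ~ p(z) = N(0, I)
x = jnp.tanh(z)                    # step 2: decode z, standing in for the mean of p_theta(x | z)
```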
Re-parameterization trick
In VAEs, the encoder determines the distribution of the latent variable $z$. In order to be able to differentiate through the samples generated by the encoder network, we can re-parameterize the $z$'s as a linear transformation of a sample from a pre-determined distribution like a standard Gaussian: $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$ where $\epsilon \sim \mathcal{N}(0, I)$, so that we can differentiate through the encoder outputs $\mu_\phi(x)$ and $\sigma_\phi(x)$.
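A sketch of the trick in JAX, assuming a toy objective over the sample; the only point is that `jax.grad` can flow through $\mu$ and $\log\sigma$ because the randomness lives entirely in $\epsilon$:

```python
import jax
import jax.numpy as jnp

def sample_z(params, eps):
    mu, log_sigma = params
    return mu + jnp.exp(log_sigma) * eps   # z = mu + sigma * eps, differentiable w.r.t. (mu, log_sigma)

def loss(params, eps):
    z = sample_z(params, eps)
    return jnp.sum(z ** 2)                 # toy objective over the sample, not from the post

key = jax.random.PRNGKey(0)
eps = jax.random.normal(key, (2,))         # eps ~ N(0, I), fixed w.r.t. the parameters
params = (jnp.zeros(2), jnp.zeros(2))      # mu = 0, log_sigma = 0
grads = jax.grad(loss)(params, eps)        # well-defined thanks to the re-parameterization
print(grads)
```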
Bayes' Theorem
Given the prior (assumptions about our observations) $p(z)$, the posterior (discriminative model) $p(z \mid x)$, the likelihood (generative model) $p(x \mid z)$, and the joint $p(x, z) = p(x \mid z)\,p(z)$:

$$p(z \mid x) = \frac{p(x \mid z)\,p(z)}{p(x)}$$
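A tiny numerical sanity check of the theorem with a made-up discrete latent $z \in \{0, 1\}$:

```python
import jax.numpy as jnp

prior = jnp.array([0.7, 0.3])          # p(z) for z = 0, 1 (made-up numbers)
likelihood = jnp.array([0.2, 0.9])     # p(x = 1 | z) for z = 0, 1
joint = likelihood * prior             # p(x = 1, z)
evidence = joint.sum()                 # p(x = 1), the marginal over z
posterior = joint / evidence           # p(z | x = 1) by Bayes' theorem
print(posterior)                       # [0.341..., 0.658...]
```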
Inference on $z$
We want to learn a posterior distribution over the latent (un-observable) variables $z$. The MCMC approach to doing so is via Bayesian inference:

$$p(z \mid x) = \frac{p(x \mid z)\,p(z)}{\int p(x \mid z)\,p(z)\,dz}$$

Notice that the marginal distribution $p(x) = \int p(x \mid z)\,p(z)\,dz$ in the denominator becomes intractable for very large (high-dimensional) $z$.
In VI, we instead learn the parameters $\phi$ of an approximate posterior $q_\phi(z \mid x)$ restricted to a tractable family of distributions that is factorizable (aka "mean-field"). We do stochastic optimization by maximizing an objective representing a lower bound on the log marginal likelihood: the ELBO.
Check out my JAX implementation of a VAE trained on a toy MNIST classification task here.
Deep learning algorithms often operate on log probabilities for numerical stability (i.e. no NaNs). The intuition is that when multiplying probabilities $\prod_i p_i$, the products generally (with the exception of $p_i = 1$) become smaller and smaller. This becomes a problem if there is limited precision available to represent such small values in the machines executing our computations. It is therefore easier to compute larger magnitudes with more precision, i.e. $\log \prod_i p_i = \sum_i \log p_i$, where each $\log p_i \le 0$.
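A quick demonstration with made-up probabilities: the raw product underflows in float32, while the sum of logs stays perfectly representable.

```python
import jax.numpy as jnp

p = jnp.full(2000, 0.01)        # 2000 probabilities of 0.01 each (made up)
print(jnp.prod(p))              # 0.0 -- the product underflows to zero
print(jnp.sum(jnp.log(p)))      # ~ -9210.34, finite and easy to work with
```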
Jensen's Inequality
Let $f$ be a convex function. Then,

$$f(\mathbb{E}[X]) \le \mathbb{E}[f(X)]$$

(For a concave function such as $\log$, the inequality flips: $\log \mathbb{E}[X] \ge \mathbb{E}[\log X]$, which is the form used in the ELBO derivation below.)
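A Monte Carlo sanity check, using $f = \exp$ as the convex function and standard normal samples (both choices are just for illustration):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (100_000,))                          # X ~ N(0, 1)
print(jnp.exp(x.mean()), jnp.exp(x).mean())                     # ~1.0 <= ~1.65: f(E[X]) <= E[f(X)]
print(jnp.log(jnp.exp(x).mean()), jnp.log(jnp.exp(x)).mean())   # ~0.5 >= ~0.0: log E[Y] >= E[log Y]
```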
KL Divergence
Let $p, q$ be the true and learned distributions respectively. The KL divergence is a measure of the number of bits (or nats) required to morph one distribution into the other, and we generally minimize this during optimization. There are two forms, the latter of which is utilized in VI:
- Forward KL ("mode covering"): $D_{\mathrm{KL}}(p \,\|\, q) = \mathbb{E}_{z \sim p}\left[\log \frac{p(z)}{q(z)}\right]$. Notice that since $q$ is in the denominator, if it is small anywhere $p$ has mass, the penalty sums up to be very large. Thus $q$ "covers" as much of the support of $p$ as possible, assigning density even to the lowest-density areas of $p$.
- Reverse KL ("mode seeking"): $D_{\mathrm{KL}}(q \,\|\, p) = \mathbb{E}_{z \sim q}\left[\log \frac{q(z)}{p(z)}\right]$. Since $p$ is in the denominator, $q$ cannot place mass where $p$ is (near) zero. Thus, the algorithm ends up finding "modes", i.e. high-density areas of $p$, and there is no division-by-zero problem as long as $q(z) = 0$ wherever $p(z) = 0$ (see the numerical illustration below).
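A numerical illustration of the asymmetry, assuming a toy bimodal $p$ and a unimodal $q$ parked on one of its modes; both distributions are made up for this example, and the KLs are crude Monte Carlo estimates:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

def log_p(x):
    # p: an equal mixture of N(-3, 1) and N(+3, 1)
    comps = jnp.stack([norm.logpdf(x, -3.0, 1.0), norm.logpdf(x, 3.0, 1.0)])
    return logsumexp(comps, axis=0) + jnp.log(0.5)

def log_q(x):
    # q: a single Gaussian sitting on the right-hand mode of p
    return norm.logpdf(x, 3.0, 1.0)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
mode = jnp.where(jax.random.bernoulli(k1, 0.5, (50_000,)), 3.0, -3.0)
x_p = mode + jax.random.normal(k2, (50_000,))    # samples from p
x_q = 3.0 + jax.random.normal(k3, (50_000,))     # samples from q

print("forward KL(p || q):", jnp.mean(log_p(x_p) - log_q(x_p)))  # huge: q misses a whole mode of p
print("reverse KL(q || p):", jnp.mean(log_q(x_q) - log_p(x_q)))  # ~ log 2: q happily sits on one mode
```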
Derivation of ELBO
Training Objective
The VI objective is to minimize the KL divergence between our approximate and true posterior distributions:

$$\min_\phi\; D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big)$$

We want to surface the likelihood term $p_\theta(x \mid z)$ so that we can do standard maximum likelihood estimation (MLE), and also the KL term between the approximate posterior $q_\phi(z \mid x)$ and the prior $p(z)$:

$$
\begin{aligned}
D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big)
&= \mathbb{E}_{q_\phi(z \mid x)}\left[\log q_\phi(z \mid x) - \log p(z \mid x)\right] \\
&= \mathbb{E}_{q_\phi(z \mid x)}\left[\log q_\phi(z \mid x) - \log p_\theta(x \mid z) - \log p(z) + \log p(x)\right] \\
&= \log p(x) \;-\; \mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] \;+\; D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
\end{aligned}
$$

Notice that since $q_\phi(z \mid x)$ is a probability distribution and $\log p(x)$ does not depend on $z$, $q_\phi$ effectively integrates out to 1 under the expectation, leaving just $\log p(x)$.

The final two terms represent the negative ELBO, which exactly works out to be our optimizer's training objective of minimizing negative log-likelihood + KL:

$$-\mathrm{ELBO} = -\,\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right] + D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z)\big)$$

The marginal log likelihood $\log p(x)$ of the real data is fixed w.r.t. the variational parameters $\phi$, so minimizing $D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big)$ (approximate posterior vs. true posterior) is equivalent to maximizing the ELBO (or minimizing the negative ELBO).
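A sketch of this training objective as a JAX loss function, assuming a diagonal-Gaussian encoder output, a standard normal prior (so the KL term is analytic), a Gaussian reconstruction term, and a stand-in decoder:

```python
import jax
import jax.numpy as jnp

def neg_elbo(mu, log_sigma, x, decode, key):
    # one-sample Monte Carlo estimate of the negative ELBO
    eps = jax.random.normal(key, mu.shape)
    z = mu + jnp.exp(log_sigma) * eps                  # reparameterized z ~ q_phi(z | x)
    recon_nll = jnp.sum((x - decode(z)) ** 2)          # Gaussian reconstruction NLL, up to constants
    # analytic KL( N(mu, sigma^2) || N(0, I) )
    kl = 0.5 * jnp.sum(jnp.exp(2 * log_sigma) + mu**2 - 1.0 - 2 * log_sigma)
    return recon_nll + kl

decode = lambda z: jnp.tanh(z)                         # stand-in decoder, not a trained network
key = jax.random.PRNGKey(0)
loss = neg_elbo(jnp.zeros(2), jnp.zeros(2), jnp.ones(2), decode, key)
grads = jax.grad(neg_elbo)(jnp.zeros(2), jnp.zeros(2), jnp.ones(2), decode, key)  # d(loss)/d(mu)
```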
Why is the ELBO term a lower bound to the evidence?
Notice that there are TWO KL terms, one inside the ELBO (against the prior $p(z)$, acting as a regularizer) and another against the true posterior $p(z \mid x)$, and every KL term satisfies $D_{\mathrm{KL}} \ge 0$:

$$\log p(x) = \mathrm{ELBO} + D_{\mathrm{KL}}\big(q_\phi(z \mid x) \,\|\, p(z \mid x)\big) \;\ge\; \mathrm{ELBO}$$

By Jensen's inequality, the ELBO is therefore a lower bound on the marginal log likelihood $\log p(x)$:

$$\log p(x) = \log \mathbb{E}_{q_\phi(z \mid x)}\left[\frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right] \;\ge\; \mathbb{E}_{q_\phi(z \mid x)}\left[\log \frac{p_\theta(x, z)}{q_\phi(z \mid x)}\right] = \mathrm{ELBO},$$

which makes it sensible to optimize this quantity: if the KL divergence is zero, then our bound on $\log p(x)$ would be tight!
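A numerical check of the bound under a toy linear-Gaussian model (not from this post): with $z \sim \mathcal{N}(0,1)$ and $x \mid z \sim \mathcal{N}(z, 1)$, the evidence $p(x) = \mathcal{N}(x; 0, 2)$ is available in closed form, and the ELBO matches it exactly when $q$ equals the true posterior $\mathcal{N}(x/2, 1/2)$:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

def elbo(m, s, x):
    # q = N(m, s^2); E_q[log N(x; z, 1)] has a closed form for this toy model
    expected_loglik = norm.logpdf(x, m, 1.0) - 0.5 * s**2
    kl = 0.5 * (s**2 + m**2 - 1.0 - jnp.log(s**2))     # KL(q || N(0, 1))
    return expected_loglik - kl

x = 1.0
log_px = norm.logpdf(x, 0.0, jnp.sqrt(2.0))            # exact evidence: p(x) = N(x; 0, 2)
print(log_px, elbo(0.5, jnp.sqrt(0.5), x))             # equal: bound is tight at the true posterior N(x/2, 1/2)
print(log_px, elbo(0.0, 1.0, x))                       # ELBO strictly below log p(x) for any other q
```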
Additional Resources
If you are interested in looking at more advanced topics, I really enjoyed the following slides/papers/blogs: