Mean Field Variational Inference


Preliminaries

Some key (but non-exhaustive) concepts to be familiar with; they help with the derivations and general intuition when dealing with topics in structured representation learning!

Statistical Inference as Optimization

Variational auto-encoders (VAEs) on a high level are composed of the following:
  1. Encoder ($\phi$): learn the $\mu, \sigma$ for $z \sim \mathcal{N}(\mu, \sigma^2)$ so that we can sample from $q_\phi(z|x)$
  2. Decoder ($\theta$): learn a representation of the observations $x \sim p_\theta(x|z)$
Variational inference (VI) is an algorithm for learning the parameters $\theta$ and $\phi$, i.e. for approximating the mapping from the inputs / observations $x$ to the process that generates them, $p(x)$ (the true distribution).

To generate samples $\tilde{x}^{(i)} \sim p_\theta(x)$ that look like $x^{(i)}$ (see the sketch below):
  1. Sample from the prior: $z^{(i)} \sim p_\theta(z)$
  2. Sample from the conditional likelihood: $\tilde{x}^{(i)} \sim p_\theta(x|z^{(i)})$
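A minimal sketch of this two-step generative process in JAX, assuming a standard normal prior over $z$ and a hypothetical linear decoder with weights `dec_w` (illustrative stand-ins, not the implementation linked below):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_prior, k_dec, k_obs = jax.random.split(key, 3)

z_dim, x_dim = 2, 784  # assumed latent / pixel dimensionalities
dec_w = 0.01 * jax.random.normal(k_dec, (z_dim, x_dim))  # stand-in for the decoder p_theta(x|z)

# 1. Sample from the prior: z^(i) ~ p_theta(z) = N(0, I)
z = jax.random.normal(k_prior, (16, z_dim))

# 2. Sample from the conditional likelihood: x~^(i) ~ p_theta(x | z^(i)),
#    here a Bernoulli over pixels parameterized by the decoder output.
x_probs = jax.nn.sigmoid(z @ dec_w)
x_samples = jax.random.bernoulli(k_obs, x_probs)
```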

Re-parameterization trick 

In VAEs, the encoder determines the distribution of the latent variable $z$. In order to differentiate through samples drawn from the encoder's distribution, we re-parameterize $z$ as a deterministic transformation of a fixed, parameter-free distribution: draw $\epsilon \sim \mathcal{N}(0,1)$ and set $z = \mu + \sigma \epsilon$, so that $z \sim \mathcal{N}(\mu, \sigma^2)$. The randomness now lives outside the network, so gradients can flow through the encoder outputs $\mu$ and $\sigma$.
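A minimal sketch of the trick in JAX; the function and argument names (`sample_z`, `mu`, `logvar`) are illustrative assumptions, not the post's implementation:

```python
import jax
import jax.numpy as jnp

def sample_z(key, mu, logvar):
    """Reparameterized sample z = mu + sigma * eps, with eps ~ N(0, I).

    Gradients w.r.t. mu and logvar flow through this function because the
    randomness lives entirely in eps, which does not depend on the parameters.
    """
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps

# Example: gradients of a downstream loss w.r.t. the encoder outputs.
key = jax.random.PRNGKey(0)
mu, logvar = jnp.zeros(2), jnp.zeros(2)
loss = lambda m, lv: jnp.sum(sample_z(key, m, lv) ** 2)
grads = jax.grad(loss, argnums=(0, 1))(mu, logvar)
```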

Bayes' Theorem

Given the prior (our assumptions about the latents $z$ before seeing data) $p_\theta(z)$, the posterior (discriminative model) $p_\theta(z|x)$, the likelihood (generative model) $p_\theta(x|z)$, and the joint $p_\theta(z,x)$:
$$p_\theta(x|z)\, p_\theta(z) = p_\theta(z,x) = p_\theta(z|x)\, p_\theta(x)$$

Inference on z

We want to learn a posterior distribution over the latent (un-observable) variables $z$. One family of approaches (e.g. MCMC) targets the exact Bayesian posterior given by Bayes' rule:
$$p_\theta(z|x) = \frac{p_\theta(z)\, p_\theta(x|z)}{\int p_\theta(z)\, p_\theta(x|z)\, dz}$$
Notice that the marginal distribution in the denominator becomes intractable when $z$ is high-dimensional.
In VI, we instead learn the parameters $\phi$ of an approximate posterior $q_\phi(z)$ restricted to a tractable family of distributions that is factorizable (aka 'mean-field', spelled out just below). We then do stochastic optimization, maximizing an objective that is a lower bound on the log marginal likelihood: the ELBO.
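For completeness, the mean-field assumption just mentioned is that the approximate posterior factorizes over (groups of) the latent variables:
$$q_\phi(z) = \prod_{j=1}^{M} q_{\phi_j}(z_j)$$
In the standard VAE setting this shows up as a diagonal-covariance Gaussian $q_\phi(z|x)$, i.e. each latent coordinate $z_j$ gets its own independent $\mathcal{N}(\mu_j(x), \sigma_j^2(x))$ factor.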

Check out my JAX implementation of a VAE trained on a toy MNIST classification task here.

log p(·)

Deep learning algorithms often operate on log probabilities for numerical stability (i.e. no NaNs). The intuition is that when multiplying probabilities $p_i \in [0,1]$, the product generally (with the exception of $p_i = 1$) gets smaller and smaller. This becomes a problem when there is limited floating-point precision available to represent such tiny values on the machines executing our computations. It is therefore easier to compute with the larger-magnitude log values, using $\log(p_1 p_2) = \log(p_1) + \log(p_2)$, where $\log p_i \le 0$.
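A tiny illustration of the underflow problem and the log-space fix (the probability values are arbitrary):

```python
import jax.numpy as jnp

# 2000 independent events, each with probability 1e-3.
probs = jnp.full(2000, 1e-3)

# The naive product underflows to exactly 0.0 in float32...
print(jnp.prod(probs))          # 0.0

# ...while the equivalent sum of logs is a perfectly representable finite number.
print(jnp.sum(jnp.log(probs)))  # ~ -13815.5
```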

Jensen's Inequality

Let $f(x)$ be a convex function. Then $\mathbb{E}[f(x)] \ge f(\mathbb{E}[x])$. For example, with $f(x) = x^2$ this says $\mathbb{E}[x^2] \ge (\mathbb{E}[x])^2$, i.e. the variance is non-negative. For a concave function such as $\log$, the inequality flips: $\mathbb{E}[\log x] \le \log \mathbb{E}[x]$.

KL Divergence

Let $p, q$ be the true and learned distributions respectively. KL divergence is a measure of the number of extra bits (when the $\log$ is taken base 2, or nats divided by $\log 2$) required to represent one distribution using the other, and we generally minimize it during optimization. There are two forms, the latter of which is utilized in VI (a small Monte Carlo comparison is sketched after the list):
  • Forward KL ("mode covering"): $KL(p(x)\,\|\,q(x)) = \int p(x) \log \frac{p(x)}{q(x)}\, dx$. Notice that since $q(x)$ is in the denominator, wherever it is small while $p(x)$ is not, the integrand blows up and the penalty becomes very large. Thus $q(x)$ "covers" as much of the support where $p(x) > 0$ as possible, spreading its density over every region where $p$ has appreciable mass.
  • Reverse KL ("mode seeking"): $KL(q(x)\,\|\,p(x)) = \int q(x) \log \frac{q(x)}{p(x)}\, dx$. Since $p(x)$ is in the denominator, $q(x)$ must be (near) zero wherever $p(x)$ is; conversely, regions where $q(x) = 0$ contribute nothing to the integral. The optimization therefore ends up finding a "mode", i.e. concentrating on a high-density area of $p$, while safely ignoring the rest.
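A rough Monte Carlo sketch of this asymmetry, comparing a single Gaussian $q$ against a bimodal mixture $p$; all distribution parameters and sample sizes here are arbitrary choices for illustration:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# A bimodal "true" distribution p: an equal mixture of N(-3, 1) and N(3, 1).
def log_p(x):
    return jnp.logaddexp(norm.logpdf(x, -3.0, 1.0),
                         norm.logpdf(x, 3.0, 1.0)) - jnp.log(2.0)

def log_q(x, mu, sigma):
    return norm.logpdf(x, mu, sigma)

key = jax.random.PRNGKey(0)
k_comp, k_p, k_q = jax.random.split(key, 3)

# Samples from p: pick a mixture component, then sample it.
comp = jax.random.bernoulli(k_comp, 0.5, (50_000,))
x_p = jnp.where(comp, 3.0, -3.0) + jax.random.normal(k_p, (50_000,))

def kls(mu, sigma):
    x_q = mu + sigma * jax.random.normal(k_q, (50_000,))
    forward = jnp.mean(log_p(x_p) - log_q(x_p, mu, sigma))   # E_p[log p - log q]
    reverse = jnp.mean(log_q(x_q, mu, sigma) - log_p(x_q))   # E_q[log q - log p]
    return forward, reverse

print(kls(0.0, 3.0))  # broad q covering both modes: much lower *forward* KL
print(kls(3.0, 1.0))  # narrow q locked onto one mode: lower *reverse* KL
```

The mode-covering solution spreads its mass over both bumps of $p$, while the mode-seeking solution happily collapses onto one of them.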

Derivation of ELBO

Training Objective

The VI objective is to minimize the KL divergence between our approximate and true posterior distributions:
$$KL[q_\phi(z|x)\,\|\,p_\theta(z|x)] = \int q_\phi(z|x) \log \frac{q_\phi(z|x)}{p_\theta(z|x)}\, dz = \int q_\phi(z|x) \log \frac{p_\theta(x)\, q_\phi(z|x)}{p_\theta(z,x)}\, dz$$

We want to surface the likelihood term $p_\theta(x|z)$ so that we can do standard maximum likelihood estimation (MLE), and also the KL term between the approximate posterior $q_\phi(z|x)$ and the prior $p_\theta(z)$:
$$= \int q_\phi(z|x) \log p_\theta(x)\, dz + \int q_\phi(z|x) \log \frac{q_\phi(z|x)}{p_\theta(z,x)}\, dz$$
$$= \int q_\phi(z|x) \log p_\theta(x)\, dz + \int q_\phi(z|x) \log \frac{q_\phi(z|x)}{p_\theta(x|z)\, p_\theta(z)}\, dz$$

Notice that since $\log p_\theta(x)$ does not depend on $z$, it can be pulled out of the first integral, and the remaining $\int q_\phi(z|x)\, dz$ integrates to 1, leaving just $\log p_\theta(x)$.
$$= \log p_\theta(x) + \int q_\phi(z|x) \left[ \log q_\phi(z|x) - \log p_\theta(x|z) - \log p_\theta(z) \right] dz$$
$$= \log p_\theta(x) + \int q_\phi(z|x) \log \frac{q_\phi(z|x)}{p_\theta(z)}\, dz - \int q_\phi(z|x) \log p_\theta(x|z)\, dz$$
$$= \log p_\theta(x) + \mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log \frac{q_\phi(z|x)}{p_\theta(z)} \right] - \mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log p_\theta(x|z) \right]$$

The final two terms are the negative ELBO, which works out to be exactly our optimizer's training objective: minimize the (expected) negative log-likelihood plus a KL term.
$$KL[q_\phi(z|x)\,\|\,p_\theta(z|x)] = \log p_\theta(x) + KL[q_\phi(z|x)\,\|\,p_\theta(z)] - \underbrace{\mathbb{E}_{z \sim q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{log likelihood of observations}} = \log p_\theta(x) - \text{ELBO}$$
The marginal log likelihood of the real data $\log p_\theta(x)$ is fixed w.r.t. the variational parameters $\phi$, so minimizing $KL[q_\phi(z|x)\,\|\,p_\theta(z|x)]$ (approx. posterior vs. true posterior) is equivalent to maximizing the ELBO (or, for the optimizer, minimizing $-\text{ELBO}$).
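A minimal sketch of that objective in JAX for the usual diagonal-Gaussian encoder and Bernoulli decoder, using the closed-form KL to a standard normal prior; the function name and argument shapes are assumptions for illustration, not the linked implementation:

```python
import jax
import jax.numpy as jnp

def neg_elbo(x, recon_logits, mu, logvar):
    """Negative ELBO = reconstruction NLL + KL[q_phi(z|x) || N(0, I)].

    x:            binary observations, shape (batch, x_dim)
    recon_logits: decoder outputs for p_theta(x|z) as Bernoulli logits, same shape as x
    mu, logvar:   encoder outputs defining q_phi(z|x), shape (batch, z_dim)
    """
    # -E_{z~q}[log p_theta(x|z)], estimated with the single reparameterized
    # sample of z that produced recon_logits (Bernoulli log-likelihood per pixel).
    log_px_z = jnp.sum(x * jax.nn.log_sigmoid(recon_logits)
                       + (1.0 - x) * jax.nn.log_sigmoid(-recon_logits), axis=-1)

    # KL[N(mu, sigma^2) || N(0, I)] in closed form for a diagonal Gaussian.
    kl = 0.5 * jnp.sum(jnp.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)

    return jnp.mean(kl - log_px_z)
```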

Why is the ELBO a lower bound on the evidence?

Notice that there are TWO KL terms, one inside the ELBO (against the prior) and another that is the gap between the evidence and the ELBO (against the true posterior), and every KL divergence satisfies $KL \ge 0$:
$$\log p_\theta(x) = \underbrace{\mathbb{E}_{z \sim q_\phi(z|x)}[\log p_\theta(x|z)]}_{\text{log likelihood of observations}} - \underbrace{KL[q_\phi(z|x)\,\|\,p_\theta(z)]}_{\text{KL(approx. post || prior)}} + KL[q_\phi(z|x)\,\|\,p_\theta(z|x)] = \text{ELBO} + \text{KL(approx. posterior || true posterior)} \ge \text{ELBO}$$

By Jensen's inequality (for the concave $\log$, $\log \mathbb{E}[\cdot] \ge \mathbb{E}[\log \cdot]$), the ELBO is therefore a lower bound on the marginal log likelihood:
$$\log p_\theta(x) = \log \mathbb{E}_{z \sim q_\phi(z|x)}\left[ \frac{p_\theta(x,z)}{q_\phi(z|x)} \right] \ge \mathbb{E}_{z \sim q_\phi(z|x)}\left[ \log \frac{p_\theta(x,z)}{q_\phi(z|x)} \right] = \text{ELBO}$$

So it makes sense to optimize this quantity: if the KL divergence to the true posterior is zero, then our bound on $\log p_\theta(x)$ is tight!

Additional Resources

If you are interested in looking at more advanced topics, I really enjoyed the following slides/papers/blogs:
