Hierarchical Markovian Variational Autoencoders
Diffusion part 2: A first principles approach.
This article is the second in a multi-article series exploring the fundamentals of deep generative models, where we discuss Hierarchical Markovian Variational Autoencoders (HMVAE). With advances in text-to-image generation (OpenAI's DALL-E, Imagen, Midjourney, Stable Diffusion, HuggingFace, etc.), text-to-text (ChatGPT, Chinchilla, Flamingo), and speech-to-text (Assembly AI), we should try to understand where each of these pieces comes from, as there are many mathematical concepts inspiring the design of these deep generative model architectures. We'll start with decoders and work our way up to Diffusion models, then we'll explore the fundamentals of encoders, looking at unsupervised methods like contrastive learning. After that we'll survey what's out there and how generative models are changing the world, and finally we'll provide code examples to help people get started, along with some of the current limitations of these models. Here's our outline for diving into diffusion, with link references:
- Generative Models, ELBO, Variational Autoencoders
- Hierarchical Markovian Variational Autoencoders (this)
- Diffusion (TBA).
Let’s dive in.
First Principles
Now that we've got the fundamentals down for variational autoencoders, we can extend the concept from a single encoding and decoding step between two states (x to z) to multiple sequential encoding and decoding steps between a sequence of states.
Some core assumptions are that all states have the same dimension and that we encode or decode sequentially, starting from some ground truth x_0 in our dataset D. What we've drawn in figure 1 is a generative process where every state at timestep t can be written as an encoded probability and a decoded probability, for the encoding and decoding processes respectively (see figure 2).
Note that an encoder moves forward in time, while a decoder moves backward in time during generation.
We could work with this current setup and derive lower bounds similar to the ones we derived for the VAE; however, we make one more assumption that will simplify our lives: the Markov assumption. The Markov assumption states that in a Markov process (which is what we have in figure 1) the most recent state contains all the information necessary for predicting the next state. Put differently, conditioning on the current state to predict the next state of the process is equivalent to conditioning on the current state and all previous states. In our current setup this translates mathematically to:
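In symbols (writing q for the encoding process and p for the decoding process):

```latex
q(x_t \mid x_{t-1}, x_{t-2}, \dots, x_0) = q(x_t \mid x_{t-1}), \qquad
p(x_{t-1} \mid x_t, x_{t+1}, \dots, x_T) = p(x_{t-1} \mid x_t)
```

so the full encoding and decoding trajectories factor into products of single-step transitions:

```latex
q_\phi(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q_\phi(x_t \mid x_{t-1}), \qquad
p_\theta(x_{0:T}) = p(x_T) \prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)
```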
Referencing the VAE ELBO derivation and using a sequence of states instead of single states, we arrive at a similar formulation. We're going to derive this slightly differently than we did previously, using Jensen's inequality (because it makes the proof a bit easier; note that the gap discussed in the VAE article is equivalent).
ELBO derivation
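Here's a sketch of the derivation, writing q(x_{1:T} | x_0) for the full encoding process and p(x_{0:T}) for the joint decoding process, as in the factorization above:

```latex
\begin{aligned}
\log p(x_0) &= \log \int p(x_{0:T})\, dx_{1:T} \\
&= \log \int \frac{p(x_{0:T})\, q(x_{1:T} \mid x_0)}{q(x_{1:T} \mid x_0)}\, dx_{1:T} \\
&= \log \mathbb{E}_{q(x_{1:T} \mid x_0)}\!\left[ \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] \\
&\geq \mathbb{E}_{q(x_{1:T} \mid x_0)}\!\left[ \log \frac{p(x_{0:T})}{q(x_{1:T} \mid x_0)} \right] \quad \text{(Jensen's inequality)}
\end{aligned}
```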
Why is this helpful?
Let's take a step back for a second and reflect on what it is we're trying to do. Recall that, just like in the VAE article, we're trying to learn a generative model, something that estimates p(x_0). We do this by thinking of x_0 as being part of some sequence of traversals through different x_t states via an encoding and decoding process. Under this framework we were able to derive a lower bound for the probability p(x_0), as shown in figure 4. Essentially, what the equation in figure 4 says is that the expected ratio, under the encoding process, of the probability of a decoded trajectory (x_T → x_0) to the probability of an encoded trajectory (x_0 → x_T) lower bounds the true probability of x_0! By making the problem seemingly more complicated, we've now imposed some learnable structure that lower bounds the true data distribution of x! Note that the structure is represented by θ and φ (the learnable parameters of our decoder and encoder).
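To make this concrete, here's a minimal sketch of a Monte Carlo estimate of that lower bound. The encoder, decoder, and prior interfaces are hypothetical (callables returning torch.distributions objects); this is meant to illustrate the ratio in figure 4, not serve as a reference implementation.

```python
import torch

def elbo_estimate(x0, encoder, decoder, prior, T, n_samples=8):
    """Monte Carlo estimate of E_q[log p(x_{0:T}) - log q(x_{1:T} | x_0)].

    Hypothetical interfaces: encoder(x, t) and decoder(x, t) return
    torch.distributions objects for q(x_t | x_{t-1}) and p(x_{t-1} | x_t);
    prior is a torch.distributions object for p(x_T).
    """
    total = 0.0
    for _ in range(n_samples):
        # Encode forward (x_0 -> x_T), accumulating log q(x_{1:T} | x_0).
        xs, log_q = [x0], 0.0
        for t in range(1, T + 1):
            q_t = encoder(xs[-1], t)
            x_t = q_t.rsample()  # reparameterized sample keeps gradients
            log_q = log_q + q_t.log_prob(x_t).sum()
            xs.append(x_t)
        # Score the same trajectory under the decoder (x_T -> x_0):
        # log p(x_{0:T}) = log p(x_T) + sum_t log p(x_{t-1} | x_t).
        log_p = prior.log_prob(xs[T]).sum()
        for t in range(T, 0, -1):
            log_p = log_p + decoder(xs[t], t).log_prob(xs[t - 1]).sum()
        total = total + (log_p - log_q)
    return total / n_samples
```

Maximizing this estimate with respect to θ and φ (the parameters inside encoder and decoder) pushes the bound up toward log p(x_0).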
How do we learn our encoder and decoder?
Approaches
The next steps in our approach start with thinking about:
- Given samples of our data x_0 in our dataset D, how do we learn θ and φ?
- What model or distribution family should we pick for θ and φ? (Note: generally picking this structure is the job of an ML practitioner).
Approach 0: Align the encoding and decoding process
There are a few ways to address the first step. One way is to try to align the decoding process and the encoding process between the states shown in figure 5.
The key idea is that the decoding step should yield the same state as the encoding step (see figure 7).
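In symbols, the hope is that for every intermediate timestep t the two single-step transitions agree (a compact restatement of the idea in the figure):

```latex
q_\phi(x_t \mid x_{t-1}) \approx p_\theta(x_t \mid x_{t+1})
```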
Now let's see if we can massage our ELBO estimate into something that captures this idea. Note that conditioning on x_0 is unnecessary under the Markov assumption; however, remembering that we have access to the ground truth x_0 will become valuable later on. Also note that the following proof could continue from our proof in figure 4, but for clarity we restate it from the beginning.
To summarize, what we've just done is lower bound the probability of our sample x_0 using three parts:
- a prior matching term that tries to align our starting prior p(x_T) with the predicted prior (see figure 9),
- a reconstruction term that tries to minimize the entropy of our reconstruction of x_0 given the encoder distribution (this prevents the predictions from collapsing to 0), and
- a consistency term that follows from our original goal of aligning the encoding and decoding steps (figure 6).
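Written out, one standard form of this decomposition is (the underbraces label the three terms; q_φ is the encoder and p_θ the decoder):

```latex
\log p(x_0) \geq
\underbrace{\mathbb{E}_{q_\phi(x_1 \mid x_0)}\!\left[\log p_\theta(x_0 \mid x_1)\right]}_{\text{reconstruction}}
- \underbrace{\mathbb{E}_{q_\phi(x_{T-1} \mid x_0)}\!\left[ D_{\mathrm{KL}}\big( q_\phi(x_T \mid x_{T-1}) \,\|\, p(x_T) \big) \right]}_{\text{prior matching}}
- \sum_{t=1}^{T-1} \underbrace{\mathbb{E}_{q_\phi(x_{t-1}, x_{t+1} \mid x_0)}\!\left[ D_{\mathrm{KL}}\big( q_\phi(x_t \mid x_{t-1}) \,\|\, p_\theta(x_t \mid x_{t+1}) \big) \right]}_{\text{consistency}}
```

Note that the consistency term's expectation is over two random variables, x_{t-1} and x_{t+1}, for every timestep t; this detail is what motivates the next approach.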
Approach 1: Use Bayes Rule.
Now, if we were to go with approach 0 and optimize the objective shown in figure 8, we'd likely notice that our results are not all that impressive and suffer from noise. The issue is that our consistency term requires Monte Carlo sampling over two random variables (x_{t-1} and x_{t+1}) when predicting x_t, which makes the variance of our estimates large. This leads us to another question: can we reduce the consistency term to just one random variable during prediction? I.e., could we just work with:
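That is, a term whose expectation runs over just one sampled state, something of the form:

```latex
\mathbb{E}_{q(x_t \mid x_0)}\!\left[ D_{\mathrm{KL}}\big( q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t) \big) \right]
```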
The answer is Bayes rule.
As shown in figures 11 and 12, if we want to focus on our decoding step (which makes sense because eventually we'll want to sample from the decoding distribution and not the encoding distribution), then we can match the learned decoding process to the ground truth decoding process, which is an application of Bayes rule! Now instead of needing two variables we just need one. The proof proceeds as follows:
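A sketch of the key step: conditioning the encoder transition on the ground truth x_0 (which is legal under the Markov assumption, since q(x_t | x_{t-1}) = q(x_t | x_{t-1}, x_0)) and applying Bayes rule,

```latex
q(x_t \mid x_{t-1}, x_0) = \frac{q(x_{t-1} \mid x_t, x_0)\, q(x_t \mid x_0)}{q(x_{t-1} \mid x_0)}
```

Substituting this into the ELBO and simplifying (the intermediate ratios telescope away) yields the denoising form of the bound:

```latex
\log p(x_0) \geq
\mathbb{E}_{q(x_1 \mid x_0)}\!\left[\log p_\theta(x_0 \mid x_1)\right]
- D_{\mathrm{KL}}\big( q(x_T \mid x_0) \,\|\, p(x_T) \big)
- \sum_{t=2}^{T} \mathbb{E}_{q(x_t \mid x_0)}\!\left[ D_{\mathrm{KL}}\big( q(x_{t-1} \mid x_t, x_0) \,\|\, p_\theta(x_{t-1} \mid x_t) \big) \right]
```

Each KL term now matches our learned decoding step p_θ(x_{t-1} | x_t) against the ground truth decoding step q(x_{t-1} | x_t, x_0), and its expectation requires sampling only x_t.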
Given this formulation, our denoising matching terms will have lower variance compared to the consistency terms in approach 0. The caveat to this approach is that we need to know the conditional prior in figure 14, which may be difficult to compute.
Also, if the encoder is learned, we're essentially learning two encoding processes simultaneously, which may be challenging. What if we assume that our encoders are Gaussian?
Cue Diffusion, which you’ve likely heard so much about. In the next article we’ll discuss what happens when we assume our encoder is Gaussian.
As always, I hope that this is helpful and I look forward to continuing our journey through the land of generative models.