VQ-DRAW: A New Generative Model

Today I am extremely excited to release VQ-DRAW, a project which has been brewing for a few months now. I am really happy about my initial set of results, and I’m so pleased to be sharing this research with the world. The goal of this blog post is to describe, at a high level, what VQ-DRAW does and how it works. For a more technical description, you can always check out the paper or the code.

In essence, VQ-DRAW is just a fancy compression algorithm. As it turns out, when a compression algorithm is good enough, decompressing random data produces realistic-looking results. For example, when VQ-DRAW is trained to compress pictures of faces into 75 bytes, this is what happens when it decompresses random data:

Samples from a VQ-DRAW model trained on the CelebA dataset.

Pretty cool, right? Of course, there are plenty of other generative models out there that can dream up realistic-looking images. VQ-DRAW has two properties that set it apart:

  1. The latent codes in VQ-DRAW are discrete, allowing direct application to lossy compression without introducing extra overhead for an encoding scheme. Of course, other methods can produce discrete latent codes as well, but most of these methods either don’t produce high-quality samples, or require additional layers of density modeling (e.g. PixelCNN) on top of the latent codes.
  2. VQ-DRAW encodes and decodes data in a sequential manner. This means that reconstruction quality degrades gracefully as latent codes are truncated. This could be used, for example, to implement progressive loading for images and other kinds of data (a rough sketch of this appears below).

MNIST digits reconstructed as more latent codes become available. The reconstructions gradually become crisper, which is ideal for progressive loading.
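To make the progressive-loading idea from point 2 concrete, here is a minimal, hypothetical sketch. The names (`decode_stage`, `initial_canvas`, `progressive_previews`) are illustrative stand-ins, not the actual VQ-DRAW API: a receiver decodes however many stages of the latent code have arrived so far, and each extra stage sharpens the preview.

```python
def progressive_previews(decode_stage, latent_prefix, initial_canvas):
    """Hypothetical progressive loading: decode only the stages received so far.

    decode_stage(recon, symbol) -> the reconstruction refined by one stage's choice,
                                   where symbol is an integer in [0, K).
    latent_prefix               -> the first few symbols of the latent code.
    """
    recon = initial_canvas
    previews = []
    for symbol in latent_prefix:
        recon = decode_stage(recon, symbol)  # add one stage of detail
        previews.append(recon)               # early previews are coarse, later ones crisper
    return previews
```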

The core algorithm behind VQ-DRAW is actually quite simple. At each stage of encoding, a refinement network looks at the current reconstruction and proposes K variations of it. The variation with the lowest reconstruction error is chosen and passed along to the next stage. Each stage therefore adds log2(K) bits of information to the latent code, so after N stages of encoding the latent code is N*log2(K) bits long.
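
To make this loop concrete, here is a minimal sketch of greedy stage-wise encoding in PyTorch. The names (`refine`, `encode`, `initial_canvas`) are hypothetical stand-ins, and details such as conditioning the refinement network on the stage index are omitted; this is not the actual VQ-DRAW implementation.

```python
import torch

def encode(refine, images, num_stages, initial_canvas):
    """Greedy stage-wise encoding (hypothetical names, simplified).

    refine(recon) -> [batch, K, C, H, W] tensor of K proposed variations.
    images        -> [batch, C, H, W] batch of target images.
    Returns integer codes of shape [batch, num_stages] plus the final reconstructions.
    """
    recon = initial_canvas.expand_as(images).clone()
    codes = []
    for _ in range(num_stages):
        options = refine(recon)                                              # [batch, K, C, H, W]
        errors = ((options - images.unsqueeze(1)) ** 2).flatten(2).sum(-1)   # [batch, K]
        best = errors.argmin(dim=1)                                          # log2(K) bits per stage
        recon = options[torch.arange(len(images)), best]                     # keep the best option
        codes.append(best)
    return torch.stack(codes, dim=1), recon
```

For instance, with K = 8 options per stage and N = 3 stages (as in the figure below), each stage contributes 3 bits, giving a 9-bit latent code.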

VQ-DRAW encoding an MNIST digit in stages. In this simplified example, K = 8 and N = 3.

To train the refinement network, we can simply minimize the L2 distance between images in the training set and their reconstructions. It might seem surprising that this works. After all, why should the network learn to produce useful variations at each stage? Well, as a side-effect of selecting the best refinement options for each input, the refinement network is influenced by a vector quantization effect (the “VQ” in “VQ-DRAW”). This effect pulls different options towards different samples in the training set, causing a sort of clustering to take place. In the first stage, this literally means that the network clusters the training samples in a way similar to k-means.
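
Here is a correspondingly rough sketch of a single training step, using the same hypothetical names as the encoding sketch above. One simplification to flag loudly: each stage's input reconstruction is detached here, so gradients do not flow back through earlier stages; the paper, not this sketch, is the reference for how training is actually scheduled.

```python
import torch

def training_step(refine, optimizer, images, num_stages, initial_canvas):
    """One gradient step: encode greedily, sum the L2 error of the chosen options.

    The argmin selection is where the vector-quantization effect comes from: each of
    the K options only receives gradient from the samples it currently fits best,
    which pulls different options toward different clusters of the training set.
    """
    recon = initial_canvas.expand_as(images).clone()
    loss = 0.0
    for _ in range(num_stages):
        options = refine(recon)                                               # [batch, K, C, H, W]
        errors = ((options - images.unsqueeze(1)) ** 2).flatten(2).mean(-1)   # [batch, K]
        best = errors.argmin(dim=1)
        loss = loss + errors[torch.arange(len(images)), best].mean()
        recon = options[torch.arange(len(images)), best].detach()             # simplification
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return float(loss)
```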

With this simple idea and training procedure, we can train VQ-DRAW on any kind of data we want. However, I only experimented with image datasets for the paper (something I definitely plan to rectify in the future). Here are some image samples generated by VQ-DRAW:

Samples from VQ-DRAW models trained on four different datasets.

CIFAR images being decoded stage-by-stage. Each stage adds a few bits of information.

At the time of this release, there is still a lot of work left to be done on VQ-DRAW. I want to see how well the method scales to larger image datasets like ImageNet. I also want to explore various ways of improving VQ-DRAW’s modeling power, which could make it generate even better samples. Finally, I’d like to try VQ-DRAW on various kinds of data beyond images, such as text and audio. Wish me luck!