Data and Machines

Suppose the year is 2000 BC and you want to set eyes upon a specific, non-naturally occurring color. To create this color, you must look for the right combination of plants, sand, soils, and clay to produce paint–and even then, there is no guarantee you will succeed exactly. Today, you can search “color picker” on the internet and find a six-digit hex code representing the color in question. Not only can your phone display the color corresponding to the hex code–you can also send this hex code to anybody on Earth and they will see nearly the exact same color. You can just as easily manipulate the hex code, making it darker or lighter, inverting it, or making it “more green”. What’s more, you can use a camera to extract hex codes from natural images and manipulate these in exactly the same way as any other hex code.
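
Those manipulations are easy to make precise. A minimal sketch in Python (the helper names are mine, not any standard API):

```python
def hex_to_rgb(code):
    """Parse a '#rrggbb' hex code into an (r, g, b) tuple of 0-255 ints."""
    code = code.lstrip("#")
    return tuple(int(code[i:i + 2], 16) for i in (0, 2, 4))

def rgb_to_hex(rgb):
    """Format an (r, g, b) tuple back into a '#rrggbb' hex code."""
    return "#" + "".join("%02x" % c for c in rgb)

def darken(code, amount=0.2):
    """Scale every channel down to make the color darker."""
    return rgb_to_hex(tuple(int(c * (1 - amount)) for c in hex_to_rgb(code)))

def invert(code):
    """Replace every channel c with 255 - c."""
    return rgb_to_hex(tuple(255 - c for c in hex_to_rgb(code)))
```

Inverting, darkening, or "adding green" are all a few arithmetic operations on three small integers, which is exactly what makes the data structure so portable.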

What I’ve described is what happens when powerful machines meet powerful data structures. A data structure, in this case a hex code, is useless on its own. You could easily write down “#65bcd4” 200 years ago, but it would not be usable like it is today. The opposite is also true: it would be very difficult to create computer screens or cameras without defining an abstract data structure to represent color. The benefit of data structures is that they give us–and machines–an easy way to manipulate and duplicate state from the physical world.

For most of recorded history, humans were the main machines that operated on data structures. Language is perhaps the biggest example of structured data that is arguably useless on its own, but incredibly powerful when paired with the right machines. Another example is music notation (with the accompanying abstractions such as notes and units of time), where humans (often paired with instruments) are the decoding machines. However, things are changing, and non-human machines can now read, write, and manipulate many forms of data that were once exclusively human endeavors.

The right combination of machines and data structures can be truly revolutionary. For example, the ability to capture images and videos from the world, manipulate them in abstract form within a computer program, and then display the resulting imagery on a screen has transformed how people live their lives.

A good data structure can even shape how people think about the physical world. For example, defining colors in terms of three basis components should not obviously work, and requires knowledge of how the human eye perceives light. But now that we have done this research, any human who learns about the XYZ color space will have a clear picture of exactly which human-perceivable colors can be created and how they relate to one another.

It’s worth noting that data structures paired with powerful machines are not always revolutionary, at least not right away. What if, for example, we tried to do the same thing with smell as we did for color: create a data structure for representing smells using base components, and then build machines for producing and detecting smells using our new data structure? Well, this has been tried many times, but it has yet to break into the average person’s household. The reasons are a bit unclear to me, but it doesn’t appear to be a pure machine or data structure bottleneck. Sure, we cannot yet build the right machines, but we also don’t understand the human olfactory system well enough to define a perfect smell data structure.

There are also plenty of examples of existing–although not rigorously defined–data structures that could be brought to life with the right machines. Imagine food 3D printers that can follow a human-written recipe, or machines that apply a full face of makeup to match a photograph. I think millions of people would spend money to download a file and immediately have a face of makeup that looks exactly like Lady Gaga on the red carpet. And these same customers would probably also love to be able to tell Alexa to add more salt to last night’s dinner and print it again next weekend. Here, I think the main bottleneck is that we simply don’t have the machines yet; the data structures themselves are fun and possibly even easy to dream up.

I’d argue that, most of the time, we already have the data structures; the problem is that humans are the only machines that can operate on them. For these sorts of problems, machine learning and robotics may one day offer a reasonable solution. Humans are still the machines making recipes, putting on makeup, cutting hair, etc. The data structures we use for these tasks are often encoded in natural language or images, and putting them into physical form requires dexterous manipulation and intelligence. Even decoding these concepts from the world is sometimes out of reach of current machines (e.g. “describe the haircut in this photograph”). The last example also hints that machine learning may make it easier to translate between different, largely equivalent data structures.

The end result of this mechanization will be truly amazing, and perhaps frightening. Imagine a world, not so long from now, where things we today consider “uncopyable” or “analog” become digital and easily manipulatable. These could include haircuts, food, physical objects, or even memories and personalities. This seems to be the natural conclusion of technological advancement, and I’m excited and horrified to witness some of it in my lifetime.

VQ-DRAW: A New Generative Model

Today I am extremely excited to release VQ-DRAW, a project which has been brewing for a few months now. I am really happy about my initial set of results, and I’m so pleased to be sharing this research with the world. The goal of this blog post is to describe, on a high-level, what VQ-DRAW does and how it works. For a more technical description, you can always check out the paper or the code.

In essence, VQ-DRAW is just a fancy compression algorithm. As it turns out, when a compression algorithm is good enough, decompressing random data produces realistic-looking results. For example, when VQ-DRAW is trained to compress pictures of faces into 75 bytes, this is what happens when it decompresses random data:

Samples from a VQ-DRAW model trained on the CelebA dataset.

Pretty cool, right? Of course, there are plenty of other generative models out there that can dream up realistic-looking images. VQ-DRAW has two properties that set it apart:

  1. The latent codes in VQ-DRAW are discrete, allowing direct application to lossy compression without introducing extra overhead for an encoding scheme. Of course, other methods can produce discrete latent codes as well, but most of these methods either don’t produce high-quality samples, or require additional layers of density modeling (e.g. PixelCNN) on top of the latent codes.
  2. VQ-DRAW encodes and decodes data in a sequential manner. This means that reconstruction quality degrades gracefully as latent codes are truncated. This could be used, for example, to implement progressive loading for images and other kinds of data.
MNIST digits as more latent codes become available. The reconstructions gradually become crisper, which is ideal for progressive loading.

The core algorithm behind VQ-DRAW is actually quite simple. At each stage of encoding, a refinement network looks at the current reconstruction and proposes K variations of it. The variation with the lowest reconstruction error is chosen and passed along to the next stage. Thus, each stage refines the reconstruction by adding log2(K) bits of information to the latent code. When we run N stages of encoding, the latent code is thus N*log2(K) bits long.
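
The greedy stage-wise encoder is easy to sketch in numpy. Here `refine` is an arbitrary callable standing in for the learned refinement network; the real decoder simply replays the proposals indicated by the chosen indices:

```python
import numpy as np

def encode(x, refine, num_stages):
    """Greedy stage-wise encoding, as a sketch of the VQ-DRAW idea.

    x:       the target array to compress.
    refine:  a function (reconstruction, stage) -> list of K candidate
             reconstructions (a learned network in the real model).
    Returns one integer in [0, K) per stage (the latent code) plus the
    final reconstruction. Each stage adds log2(K) bits to the code.
    """
    recon = np.zeros_like(x)
    latents = []
    for stage in range(num_stages):
        options = refine(recon, stage)                     # K variations
        errors = [float(((x - opt) ** 2).sum()) for opt in options]
        best = int(np.argmin(errors))                      # lowest L2 error
        latents.append(best)
        recon = options[best]
    return latents, recon
```

Truncating `latents` simply means stopping the replay early, which is where the graceful degradation for progressive loading comes from.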

VQ-DRAW encoding an MNIST digit in stages. In this simplified example, K = 8 and N = 3.

To train the refinement network, we can simply minimize the L2 distance between images in the training set and their reconstructions. It might seem surprising that this works. After all, why should the network learn to produce useful variations at each stage? Well, as a side-effect of selecting the best refinement options for each input, the refinement network is influenced by a vector quantization effect (the “VQ” in “VQ-DRAW”). This effect pulls different options towards different samples in the training set, causing a sort of clustering to take place. In the first stage, this literally means that the network clusters the training samples in a way similar to k-means.
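
The clustering claim is easy to demonstrate with a toy numpy experiment: treat the first stage's K options as free vectors, and alternate selection with an averaging update (a stand-in for many gradient steps on the L2 loss). This is exactly Lloyd's k-means; the fixed initialization below is mine, chosen for determinism:

```python
import numpy as np

# Two well-separated clusters of training "samples".
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(-5.0, 0.5, (50, 2)),
                       rng.normal(+5.0, 0.5, (50, 2))])
options = np.array([[-1.0, 0.0], [1.0, 0.0]])  # K = 2 free "refinements"

for _ in range(10):
    # Selection: each sample picks the option with the lowest L2 error.
    assign = ((data[:, None] - options[None]) ** 2).sum(-1).argmin(axis=1)
    # Update: each chosen option is pulled toward the samples that chose
    # it (a full averaging step, standing in for many SGD steps).
    for k in range(len(options)):
        if np.any(assign == k):
            options[k] = data[assign == k].mean(axis=0)
```

After a few iterations the two options sit at the two cluster means, which is the vector quantization effect in its purest form.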

With this simple idea and training procedure, we can train VQ-DRAW on any kind of data we want. However, I only experimented with image datasets for the paper (something I definitely plan to rectify in the future). Here are some image samples generated by VQ-DRAW:

Samples from VQ-DRAW models trained on four different datasets.
CIFAR images being decoded stage-by-stage. Each stage adds a few bits of information.

At the time of this release, there is still a lot of work left to be done on VQ-DRAW. I want to see how well the method scales to larger image datasets like ImageNet. I also want to explore various ways of improving VQ-DRAW’s modeling power, which could make it generate even better samples. Finally, I’d like to try VQ-DRAW on various kinds of data beyond images, such as text and audio. Wish me luck!

Research Projects That Didn’t Pan Out

Anybody who does research knows that ideas often don’t pan out. However, it is fairly rare to see papers or blogs about negative results. This is unfortunate, since negative results can tell us just as much as positive ones, if not more.

Today, I want to share some of my recent negative results in machine learning research. I’ll also include links to the code for each project, and share some theories as to why each idea didn’t work.

Project 1: reptile-gen

Source code:

Premise: This idea came from thinking about the connection between meta-learning and sequence modeling. Research has shown that sequence modeling techniques, such as self-attention and temporal convolutions (or the combination of the two), can be used as effective meta-learners. I wondered if the reverse was also true: are effective meta-learners also good sequence models?

It turns out that many sequence modeling tasks, such as text and image generation, can be posed as meta-learning problems. This means that MAML and Reptile can theoretically solve these problems on top of nothing but a feedforward network. Instead of using explicit state transitions like an RNN, reptile-gen uses a feedforward network’s parameters as a hidden state, and uses SGD to update this hidden state. More details can be found in the README of the GitHub repository.
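
The mechanism is easy to sketch. Below, a bare linear model stands in for the feedforward network, and one SGD step per sequence element plays the role of an RNN's state transition; this illustrates the premise only, since the actual reptile-gen meta-trains the initialization with MAML or Reptile:

```python
import numpy as np

def step(params, x, y, lr=0.5):
    """One 'state transition': an SGD step on a single (x, y) pair.

    The parameter vector plays the role of an RNN's hidden state;
    reading a sequence element updates it in place of an explicit
    recurrence.
    """
    pred = params @ x              # linear model standing in for the network
    grad = (pred - y) * x          # gradient of 0.5 * (pred - y)^2
    return params - lr * grad

# "Reading" a sequence leaves information in the parameters, which
# later predictions can exploit, just like an RNN hidden state.
params = np.zeros(2)
for x, y in [(np.array([1.0, 0.0]), 2.0), (np.array([0.0, 1.0]), -1.0)]:
    params = step(params, x, y)
```
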

Experiments: I applied two meta-learning algorithms, MAML and Reptile, to sequence modeling tasks. I applied both of these algorithms on top of several different feedforward models (the best one was a feedforward network that resembles one step of an LSTM). I tried two tasks: generating sequences of characters, and generating MNIST digits pixel-by-pixel.

Results: I never ended up getting high-quality samples from any of these experiments. For MNIST, I got things that looked digit-esque, but the samples were always very distorted, and my LSTM baseline converged to much better solutions with much less tuning. I did notice a big gap in performance between MAML and Reptile, with MAML consistently winning out. I also noticed that architecture mattered a lot, with the LSTM-like model performing better than a vanilla MLP. Gated activations also seemed to boost the performance of the MLPs, although not by much.

MNIST samples from reptile-gen, trained with MAML.

Takeaways: My main takeaway from this project was that MAML is truly better than Reptile. Reptile doesn’t back-propagate through the inner-loop, and as a result it seems to have much more trouble modeling long sequences. This is in contrast to our findings in the original Reptile paper, where Reptile performed about as well as MAML. How could this be the case? Well, in that paper, we were testing Reptile and MAML with small inner-loops consisting of fewer than 100 samples; in this experiment, the MNIST inner-loop had 784 samples, and the inputs were (x,y) indices (which inherently share very little information, unlike similar images).

While working on this project, I went through a few different implementations of MAML. My first implementation was very easy to use without modifying the PyTorch model at all; I didn’t expect such a plug-and-play implementation to be possible. This made my mind much more open to MAML as an algorithm in general, and I’d be very willing to use it in future projects.

Another takeaway is that sequence modeling is hard. We should feel grateful for Transformers, LSTMs, and the like. There are plenty of architectures which ought to be able to model sequences, but fail to capture long-term dependencies in practice.

Project 2: seqtree

Source code:

Premise: As I’ve demonstrated before, I am fascinated by decision tree learning algorithms. Ensembles of decision trees are powerful function approximators, and it’s theoretically simple to apply them to a diverse range of tasks. But can they be used as effective sequence models? Naturally, I wondered if decision trees could be used to model the sequences that reptile-gen failed to. I also wanted to experiment with something I called “feature cascading”, where leaves of some trees in an ensemble could generate features for future trees in the ensemble.

Experiments: Like for reptile-gen, I tried two tasks: MNIST digit generation, and text generation. I tried two different approaches for MNIST digit generation: a position-invariant model, and a position-aware model. In the position-invariant model, a single ensemble of trees looks at a window of pixels above and to the left of the current pixel, and tries to predict the current pixel; in the position-aware model, there is a separate ensemble for each location in the image, each of which can look at all of the previous pixels. For text generation, I only used a position-invariant model.
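
For concreteness, here is a sketch of the feature extraction for the position-invariant model (the window shape, padding value, and function name are my guesses; the repo's exact layout may differ):

```python
import numpy as np

def window_features(image, row, col, rows_back=2, cols_back=3):
    """Causal context for the position-invariant model: pixels above and
    to the left of (row, col), zero-padded where the window leaves the
    image. A single tree ensemble maps these features to the pixel.
    """
    feats = []
    for r in range(row - rows_back, row + 1):
        for c in range(col - cols_back, col + cols_back + 1):
            if r == row and c >= col:
                break  # never peek at the current or future pixels
            inside = 0 <= r < image.shape[0] and 0 <= c < image.shape[1]
            feats.append(float(image[r, c]) if inside else 0.0)
    return np.array(feats)
```

The position-aware model instead trains one ensemble per (row, col) location, with all previous pixels available as features.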

Results: The position-invariant models underfit drastically. For MNIST, they generated chunky, skewed digits. The position-aware model was on the other side of the spectrum, overfitting drastically to the training set after only a few trees in each ensemble. My feature cascading idea was unhelpful, and it greatly hindered runtime performance since the feature space grew rapidly with training.

Takeaways: Decision tree ensembles simply can’t do certain things well. There are two possible reasons for this: 1) there is no way to build hierarchical representations with linearly-combined ensembles of trees; 2) the greedy nature of tree building prevents complex relationships from being modeled properly, and makes it difficult to perform complex computations.

Another more subtle realization was that decision tree training is not very easy to scale. With neural networks, it’s always possible to add more neurons and put more machines in your cluster. With trees, on the other hand, there’s no obvious knob to turn to throw more compute at the problem and consistently get better results. Tree building algorithms themselves are also somewhat harder to parallelize, since they rely on fewer batched operations. I had trouble getting full CPU utilization, even on a single 64-core cloud instance.

Project 3: pca-compress

Source code:

Premise: Neural networks contain a lot of redundancy. In many cases, it is possible to match a network’s accuracy with a much smaller, carefully pruned network. However, there are some caveats that make this fact hard to exploit. First of all, it is difficult to train sparse networks from scratch, so sparsity does not help much to accelerate training. Furthermore, the best sparsity results seem to involve unstructured sparsity, i.e. arbitrary sparsity masks that are hard to implement efficiently on modern hardware.

I wanted to find a pruning method that could be applied quickly, ideally before training, that would also be efficient on modern hardware. To do this, I tried a form of rank-reduction where the linear layers (i.e. convolutional and fully-connected layers) were compressed without affecting the final number of activations coming out of each layer. I wanted to do this rank-reduction in a data-aware way, allowing it to exploit redundancy and structure in the data (and in the activations of the network while processing the data). The README of the GitHub repository includes a much more detailed description of the exact algorithms I tried.
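
A minimal numpy sketch of the simplest variant, PCA-based rank reduction of a fully-connected layer (the actual repo also handles convolutional layers and the output-aware objectives; the names here are mine):

```python
import numpy as np

def pca_reduce(weight, inputs, rank):
    """Data-aware rank reduction of a linear layer y = inputs @ weight.

    Finds the top principal directions of the layer's input activations,
    then folds the projection into two smaller matrices whose product
    approximates the original layer on similarly distributed data. The
    output dimensionality of the layer is unchanged.
    """
    centered = inputs - inputs.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    basis = vt[:rank].T                 # (in_dim, rank) projection
    # y ~= (inputs @ basis) @ (basis.T @ weight): two low-rank layers.
    return basis, basis.T @ weight

# Activations that only vary along a few directions (redundancy).
rng = np.random.default_rng(0)
inputs = rng.normal(size=(256, 4)) @ rng.normal(size=(4, 32))
weight = rng.normal(size=(32, 16))
down, up = pca_reduce(weight, inputs, rank=4)
approx = (inputs @ down) @ up
```

When the activations genuinely lie near a low-dimensional subspace, the factored layer reproduces the original outputs almost exactly, while doing dense matrix multiplies that modern hardware is good at.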

Experiments: I tried small-scale MNIST experiments and medium-scale ImageNet experiments. I call the latter “medium-scale” because I reserve “large-scale” for things like GPT-2 and BERT, both of which are out of reach for my compute. I tried pruning before and after training. For ImageNet, I also tried iteratively pruning and re-training. A lot of these experiments were motivated by the lottery ticket hypothesis, which stipulates that it may be possible to train sparse networks from scratch with the right set of initial parameters.

I tried several methods of rank-reduction. The simplest, which was based on PCA, only looked at the statistics of activations and knew nothing about the optimization objective. The more complex approaches, which I call “output-aware”, considered both inputs and outputs, trying to prevent the network’s output from changing too much after pruning.

Results: For my MNIST baseline, I was able to prune networks considerably (upwards of 80%) without any significant loss in accuracy. I also found that an output-aware pruning method was better than the simple PCA baseline. However, results and comparisons on this small baseline did not accurately predict ImageNet results.

On ImageNet, PCA pruning was uniformly the best approach. A pre-trained ResNet-18 pruned with PCA to 50% rank across all convolutional layers experienced a 10% decrease in top-1 performance before any tuning or re-training. With iterative re-training, this gap shrank to roughly 1.8%, which is still worse than the state of the art. My output-aware pruning methods resulted in a performance gap closer to 30% before re-training (much worse than PCA).

I never got around to experimenting with larger architectures (e.g. ResNet-50) and more severe levels of sparsity (e.g. 90%). Some day I may revisit this project and try such things, but at the moment I simply don’t have the compute available to run these experiments.

At the end of the day, why would I look at these results and consider this a research project that “didn’t pan out”? Mostly because I never managed to get the reduction in compute that I was hoping to achieve. My hope was that I could prune networks without a ton of compute-heavy re-training. My grand ambition was to figure out how to prune networks at or near initialization, but this did not pan out either. My only decent results were with iterative pruning, which is computationally expensive and defeats most of the purpose of the exercise. Perhaps my pruning approach could be used to find good, production-ready models, but it cannot be used (as far as I know) to speed up training time.

Takeaways: One takeaway is that results on small-scale experiments don’t always carry over to larger experiments. I developed a bunch of fancy pruning algorithms which worked well on MNIST, but none of them beat the PCA baseline on ImageNet. I never quite figured out why PCA pruning worked the best on ImageNet, but my working theory is that the L2 penalty used during training resulted in activations that had little-to-no variance in discriminatively “unimportant” directions.

Competing in the Obstacle Tower Challenge

I had a lot of fun competing in the Unity Obstacle Tower Challenge. I was at the top of the leaderboard for the majority of the competition, and for the entirety of Round 2. By the end of the competition, my agent was ranked at an average floor of 19.4, greater than the human baseline (15.6) in Juliani et al., and greater than my own personal average performance. This submission outranked all of the other submissions, by a very large margin in most cases.

So how did I do it? The simple answer is that I used human demonstrations in a clever way. There were a handful of other tricks involved as well, and this post will briefly touch on all of them. But first, I want to take a step back and describe how I arrived at my final solution.

Before I looked at the Obstacle Tower environment itself, I assumed that generalization would be the main bottleneck of the competition. This assumption mostly stemmed from my experience creating baselines for the OpenAI Retro Contest, where every model generalized terribly. As such, I started by trying a few primitive solutions that would inject as little information into the model as possible. These solutions included:

  • Evolving a policy with CMA-ES
  • Using PPO to train a policy that looked at tiny observations (e.g. 5×5 images)
  • Using CEM to learn an open-loop action distribution that maximized rewards

However, none of these solutions reached the fifth floor, and I quickly realized that a PPO baseline did better. Once I started tuning PPO, it quickly became clear that generalization was not the actual bottleneck. It turned out that the 100 training seeds were enough for standard RL algorithms to generalize fairly well. So, instead of focusing on generalization, I simply aimed to make progress on the training set (with a few exceptions, e.g. data augmentation).

My early PPO implementation, which was based on anyrl-py, was hitting a wall at the 10th floor of the environment. It never passed the 10th floor, even by chance, indicating that the environment posed too much of an exploration problem for standard RL. This was when I decided to take a closer look at the environment to see what was going on. It turned out that the 10th floor marked the introduction of the Sokoban puzzle, where the agent must push a block across a room to a square target marked on the floor. This involves taking a consistent set of actions for several seconds (on the order of ~50 timesteps). So much for traditional RL.

At this point, other researchers might have tried something like Curiosity-driven Exploration or Go-Explore. I didn’t even give these methods the time of day. As far as I can tell, these methods all have a strong inductive bias towards visually simple (often 2-dimensional) environments. Exploration is extremely easy with visually simple observations, and even simple image similarity metrics can be used with great success in these environments. On Obstacle Tower, however, observations depend completely on what the camera is pointed at, where the agent is standing in a room, etc. The agent can see two totally different images while standing in the same spot, and it can see two very similar images while standing in two totally different rooms. Moreover, the first instant of pushing a box for the Sokoban puzzle looks very similar to the final moment of pushing the same box. My hypothesis, then, was that traditional exploration algorithms would not be very effective in Obstacle Tower.

If popular exploration algorithms are out, how do we make the agent solve the Sokoban puzzle? There are two approaches I would typically try here: evolutionary algorithms, which explore in parameter space rather than action space, and human demonstrations, which bypass the problem of exploration altogether. With my limited compute capabilities (one machine with one GPU), I decided that human demonstrations would be more practical, since evolution typically burns through a lot of compute to train neural networks on games.

To start myself off with human demonstrations, I created a simple tool to record myself playing Obstacle Tower. After recording a few games, I used behavior cloning (supervised learning) to fit a policy to my demonstrations. Behavior cloning started overfitting very quickly, so I stopped training early and evaluated the resulting policy. It was terrible, but it did perform better than a random agent. I tried fine-tuning this policy with PPO, and was pleased to see that it learned faster than a policy trained from scratch. However, it did not solve the Sokoban puzzle.

Fast-forward a bit, and behavior cloning + fine-tuning still hadn’t broken through the Sokoban puzzle, even with many more demonstrations. At around this time, I rewrote my code in PyTorch so that I could try other imitation learning algorithms more easily. And while algorithms like GAIL did start to push boxes around, I hadn’t seen them reliably solve the 10th floor. I realized that the problem might involve memory, since the agent would often run around in circles doing redundant things, and it had no way of remembering if it had just seen a box or a target.

So, how did I fix the agent’s memory problem? In my experience, recurrent neural networks in RL often don’t remember what you want them to, and they take a huge number of samples to learn to remember anything useful at all. So, instead of using a recurrent neural network to help my agent remember the past, I created a state representation that I could stack up for the past 50 timesteps and then feed to my agent as part of its input. Originally, the state representation was a tuple of (action, reward, has key) values. Even with this simple state representation, behavior cloning worked way better (the test loss reached a much lower point), and the cloned agent had a much better initial score. But I didn’t stop there, because the state representation still said nothing about boxes or targets.
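
A sketch of this hand-crafted memory (the class name, padding, and flattening scheme are my guesses, not the repo's exact layout):

```python
from collections import deque

HISTORY = 50  # timesteps of memory fed to the agent

def state_tuple(action, reward, has_key):
    # One timestep of hand-crafted memory. The classifier outputs
    # described below (box seen, target seen, ...) were later appended
    # to this tuple.
    return (action, float(reward), int(has_key))

class MemoryWrapper:
    """Stacks the last HISTORY state tuples alongside each observation."""

    def __init__(self):
        # Pre-fill with neutral tuples so the input size is constant.
        self.history = deque([state_tuple(0, 0.0, False)] * HISTORY,
                             maxlen=HISTORY)

    def step(self, observation, action, reward, has_key):
        self.history.append(state_tuple(action, reward, has_key))
        # The policy sees the image plus a flat vector of recent states,
        # instead of relying on a recurrent hidden state.
        flat = [x for tup in self.history for x in tup]
        return observation, flat
```
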

To help the agent remember things that could be useful for solving the Sokoban puzzle, I trained a classifier to identify common objects like boxes, doors, box targets, keys, etc. I then added these classification outputs to the state tuple. This improved behavior cloning even more, and I started to see the behavior cloned agent solve the Sokoban puzzle fairly regularly.

Despite the agent’s improved memory, behavior cloning + fine-tuning was still failing to solve the Sokoban puzzle, and GAIL wasn’t much of an improvement. It seemed that, by the time the agent started reaching the 10th floor, it had totally forgotten how to push boxes! In my experience, this kind of forgetting in RL is often caused by the entropy bonus, which encourages the agent to take random actions as much as possible. This bonus tends to destroy parts of an agent’s pre-trained behavior that do not yield noticeable rewards right away.

This was about the time that prierarchy came in. In addition to my observation about the entropy bonus destroying the agent’s behavior, I also noticed that the behavior cloned agent took reasonable low-level actions, but it did so in ways that were unreasonable in a high-level context. For example, it might push a box all the way to the corner of a room, but it might be the wrong corner. Instead of using an entropy bonus, I wanted a bonus that would keep these low-level actions intact, while allowing the agent to solve the high-level problems that it was struggling with. This is when I implemented the KL term that makes prierarchy what it is, with the behavior cloned policy as the prior.

Once prierarchy was in place, things were pretty much smooth sailing. At this point, it was mostly a matter of recording some more demonstrations and training the agent for longer (on the order of 500M timesteps). However, there were still a few other tricks that I used to improve performance:

  • My actual submissions consisted of two agents. The first one solved floors 1-9, and the second one solved floors 10 and onward. The latter agent was trained with starting floors sampled randomly between 10 and 15. This forced it to learn to solve the Sokoban puzzle immediately, rather than perfecting floors 1-9 first.
  • I used a reduced action space, mostly because I found that it made it easier for me to play the game as a human.
  • My models were based on the CNN architecture from the IMPALA paper. In my experience with RL on video games, this architecture learns and generalizes better than the architecture used in the original Nature article.
  • I used Fixup initialization to help train deeper models.
  • I used MixMatch to train the state classifier with fewer labeled examples than I would have needed otherwise.
  • For behavior cloning, I used traditional types of image data augmentation. However, I also used a mirroring data augmentation where images and actions were mirrored together. This way, I could effectively double the number of training levels, since every level came with its mirror image as well.
  • During prierarchy training, I applied data augmentation to the Obstacle Tower environment to help with overfitting. I never actually verified that this was necessary, and it might not have been, but other contestants definitely struggled with overfitting more than I did.
  • I added a small reward bonus for picking up time orbs. It’s unclear how much of an effect this had, since the agent still missed most of the time orbs. This is one area where improvement would definitely result in a better agent.

I basically checked out for the majority of Round 2: I stopped actively working on the contest, and for a lot of the time I wasn’t even training anything for it. Near the end of the contest, when other contestants started solving the Sokoban puzzle, I trained my agent a little bit more and submitted the new version, but it turned out not to have been necessary.

My code can be found on GitHub. I do not intend to change the repository much at this point, since I want it to remain a reflection of the solution described in this post.

Prierarchy: Implicit Hierarchies

At first blush, hierarchical reinforcement learning (HRL) seems like the holy grail of artificial intelligence. Ideally, it performs long-term exploration, learns over long time-horizons, and has many other benefits. Furthermore, HRL makes it possible to reason about an agent’s behavior in terms of high-level symbols. However, in my opinion, pure HRL is too rigid and doesn’t capture how humans actually operate hierarchically.

Today, I’ll propose an idea that I call “the prierarchy” (a combination of “prior” and “hierarchy”). This idea is an alternative way to look at HRL, without the need to define rigid action boundaries or discrete levels of hierarchy. I am using the prierarchy in the Unity Obstacle Tower Challenge, so I will keep a few details close to my chest for now. However, I intend to update this post once the contest is completed.

Defining the Prierarchy

To understand the prierarchy framework, let’s consider character-level language models. When you run a language model on a sentence, there are some parts of the sentence where the model is very certain what the next token should be (low-entropy), and other parts where it is very uncertain (high-entropy). We can think of this as follows: the high-entropy parts of the sentence mark the starts of high-level actions, while the low-entropy parts represent the execution of those high-level actions. For example, the model likely has more entropy at the starts of words or phrases, while it has less entropy near the ends of words (especially long, unique words). As a consequence of this view, we can say that prompting a language model is the same thing as starting a high-level action and seeing how the model executes this high-level action.

Under the prierarchy framework, we initiate “high-level actions” by taking a sequence of low-probability, low-level actions which kick off a long sequence of high-probability, low-level actions. This assumes that we have some kind of prior distribution over low-level actions, possibly conditioned on observations from the environment. This prior basically encodes the lower-levels of the hierarchy that we want to control with a high-level policy.

The nice thing about the prierarchy is that we never have to make any specific assumptions about action boundaries, since the “high-level actions” aren’t real or concrete; they are just our interpretation of a sequence of continuous entropy values. This means, for one thing, that no low- or high-level policy has to worry about stopping conditions.

Practical Application

In order to implement the prierarchy, you first need a prior. A prior, in this case, is a distribution over low-level actions conditioned on previous observations. This is exactly a “policy” in RL. Thus, you can get a prior in any way you can get a policy: you can train a prior on a different (perhaps denser but misaligned) reward function; you can train a prior via behavior cloning; you can even hand-craft a prior using some kind of approximate solution to your problem. Whichever method you choose, you should make sure that the prior is capable of performing useful behaviors, even if those behaviors occur in the wrong order, for the wrong duration of time, etc.

Once you have a prior, you can train a controller policy fairly easily. First, initialize the policy with the prior, and then perform a policy gradient algorithm with a KL regularizer KL(policy|prior) instead of an entropy bonus. This algorithm is basically saying: “do what the prior says, and inject information into it when you need to in order to achieve rewards”. Note that, if the prior is the uniform distribution over actions, then this is exactly equivalent to traditional policy gradient algorithms.
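
As a sketch (in my own notation, not any particular implementation’s), the regularized objective looks like:

```latex
J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta}\!\left[\sum_t \gamma^t r_t\right]
          \;-\; \beta \, \mathbb{E}\!\left[\sum_t D_{\mathrm{KL}}\big(\pi_\theta(\cdot \mid s_t) \,\|\, \pi_0(\cdot \mid s_t)\big)\right]
```

Here \(\pi_\theta\) is the controller policy, \(\pi_0\) is the prior, and \(\beta\) trades off reward against staying close to the prior. If \(\pi_0\) is uniform over \(|A|\) actions, then \(D_{\mathrm{KL}}(\pi_\theta \| \pi_0) = \log|A| - H(\pi_\theta)\), so the penalty is just an entropy bonus plus a constant.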

If the prior is low entropy, then you should be able to significantly increase the discount factor. This is because the noise of policy gradient algorithms scales with the amount of information that is injected into the trajectories by sampling from the policy. Also, if you select a reasonable prior, then the prior is likely to explore much better than a random agent. This can be useful in very sparse-reward environments.

Let’s look at a simple (if somewhat contrived) example. In RL, frame-skip is used to help agents take more consistent actions and thus explore better. It also helps with credit assignment, since the policy makes fewer decisions per span of time in the environment. It should be clear that frame-skip defines a prior distribution over action sequences, where every Nth action completely determines the next N-1 actions. My claim, which I won’t demonstrate today, is that you can achieve a similar effect by pre-training a recurrent policy to mimic the frame-skip action distribution, and then applying a policy gradient algorithm to this policy with an appropriately modified discount factor (e.g. 0.99 might become 0.99^0.25) and a KL regularizer against the frame-skip prior. Of course, it would be silly to do this in practice, since the prior is so easy to bake into the environment.

There are tons of other potential applications of this idea. For example, you could learn a better action distribution via an evolutionary algorithm like CMA-ES, and then use this as the prior. This way, the long-term exploration benefits of Evolution Strategies could be combined with the short-term exploitation benefits of policy gradient methods. One could also imagine learning a policy that controls a language model prior to write popular tweets, or one that controls a self-driving car prior for the purposes of navigation.


The main benefit of HRL is really just that it compresses sequences of low-level actions, resulting in better exploration and less noise in the high-level policy gradient. That same compression can be achieved with a good low-level prior, without the need to define explicit action boundaries.

As a final note, the prierarchy training algorithm looks almost identical to pre-training + fine-tuning, indicating that it’s not a very special idea at all. Nonetheless, it seems like a powerful one, and perhaps it will save us from wasting a lot of time on HRL.

Solving murder with Go

I recently saw a blog post entitled Solving murder with Prolog. It was a good read, and I recommend it. One interesting aspect of the solution is that it reads somewhat like the original problem. Essentially, the author translated the clues into Prolog, and got a solution out automatically.

This is actually an essential skill for programmers in general: good code should be readable, almost like plain English—and not just if you’re programming in Prolog. So today, I want to re-solve this problem in Go. As I go, I’ll explain how I optimized for readability. For the eager, here is the final solution.

Constants are very useful for readability. Instead of having a magic number sitting around, try to use an intelligible word or phrase! Let’s create some constants for rooms, people, and weapons:

const (
	Bathroom = iota
	DiningRoom
	Kitchen
	LivingRoom
	Pantry
	Study
)

const (
	George = iota
	John
	Robert
	Barbara
	Christine
	Yolanda
)

const (
	Bag = iota
	Firearm
	Gas
	Knife
	Poison
	Rope
)

The way I am solving the puzzle is as follows: we generate each possible scenario, check if it meets all the clues, and if so, print it out. To do this, we can store the scenario in a structure. Here’s how I chose to represent this structure, which I call Configuration:

type Configuration struct {
	People  []int
	Weapons []int
}

Here, we represent the solution by storing the person and weapon for each room. For example, to see the person in the living room, we check cfg.People[LivingRoom]. To see the weapon in the living room, we do cfg.Weapons[LivingRoom]. This will make it easy to implement clues in a readable fashion. For example, to verify the final clue, we simply check that cfg.Weapons[Pantry] == Gas, which reads just like the original clue.

Next, we need some way to iterate over candidate solutions. Really, we just want to iterate over possible permutations of people and weapons. Luckily, I have already made a library function to generate all permutations of N indices. Using this API, we can enumerate all configurations like so:

for people := range approb.Perms(6) {
	for weapons := range approb.Perms(6) {
		cfg := &Configuration{people, weapons}
		// Check configuration here.
	}
}

Note how this reads exactly like what we really want to do. We are looping through permutations of people and weapons, and we can see that easily. This is because all the complex permutation-generation logic is abstracted away in its own well-named function.

Now we just have to filter the configurations using all the available clues. We can make a function that checks if each clue is satisfied, giving us this main loop:

for people := range approb.Perms(6) {
	for weapons := range approb.Perms(6) {
		cfg := &Configuration{people, weapons}
		if !clue1(cfg) || !clue2(cfg) || !clue3(cfg) ||
			!clue4(cfg) || !clue5(cfg) ||
			!clue6(cfg) || !clue7(cfg) ||
			!clue8(cfg) || !clue9(cfg) {
			continue
		}
		fmt.Println("killer is:", cfg.People[Pantry])
	}
}

Next, before we implement the actual clues, let’s implement some helper functions. Well-named helper functions greatly improve readability. Note in particular the isMan helper. It’s very clear what it does, and it abstracts away an arbitrary-seeming 3.

func isMan(person int) bool {
	return person < 3
}

func indexOf(l []int, x int) int {
	for i, y := range l {
		if y == x {
			return i
		}
	}
	return -1
}

Finally, we can implement all the clues in a human-readable fashion. Here is what they look like:

func clue1(cfg *Configuration) bool {
	return isMan(cfg.People[Kitchen]) && (cfg.Weapons[Kitchen] == Knife || cfg.Weapons[Kitchen] == Gas)
}

func clue2(cfg *Configuration) bool {
	if cfg.People[Study] == Barbara {
		return cfg.People[Bathroom] == Yolanda
	} else if cfg.People[Study] == Yolanda {
		return cfg.People[Bathroom] == Barbara
	} else {
		return false
	}
}

func clue3(cfg *Configuration) bool {
	bagRoom := indexOf(cfg.Weapons, Bag)
	if cfg.People[bagRoom] == Barbara || cfg.People[bagRoom] == George {
		return false
	}
	return bagRoom != Bathroom && bagRoom != DiningRoom
}

func clue4(cfg *Configuration) bool {
	if cfg.Weapons[Study] != Rope {
		return false
	}
	return !isMan(cfg.People[Study])
}

func clue5(cfg *Configuration) bool {
	return cfg.People[LivingRoom] == John || cfg.People[LivingRoom] == George
}

func clue6(cfg *Configuration) bool {
	return cfg.Weapons[DiningRoom] != Knife
}

func clue7(cfg *Configuration) bool {
	return cfg.People[Study] != Yolanda && cfg.People[Pantry] != Yolanda
}

func clue8(cfg *Configuration) bool {
	firearmRoom := indexOf(cfg.Weapons, Firearm)
	return cfg.People[firearmRoom] == George
}

func clue9(cfg *Configuration) bool {
	return cfg.Weapons[Pantry] == Gas
}

Most of these clues read almost like the original. I have to admit that the indexOf logic isn’t perfect. Let me know how/if you’d do it better!

What I Don’t Know

Starting this holiday season, I want to take some time every day to focus on a broader set of topics than just machine learning. While ML is extremely valuable, working on it day after day has given me tunnel vision. I think it’s important to remind myself that there’s more out there in the world of technology.

This post is going to be an aspirational one. I started by listing a ton of things I’m embarrassed to know nothing about. Then, in the spirit of self-improvement, I came up with a list of project ideas for each topic. I find that I learn best by doing, so I hope these projects will help me master areas where I have little or no prior experience. And who knows, maybe others will benefit from this list as well!


Networking

My knowledge of networking is severely lacking. First of all, I have no idea how the OS network stack works on either Linux or macOS. Also, I’ve never configured a complex network (e.g. for a datacenter), so I have a limited understanding of how routing works.

I am so ignorant when it comes to networking that I often struggle to formulate questions about it. Hopefully, my questions and project ideas actually make sense and turn out to be feasible.


  • What APIs does your OS provide to intercept or manipulate network traffic?
  • What actually happens when you connect to a WiFi network?
  • What is a network interface? How is traffic routed through network interfaces?
  • How do VPNs work internally (both on the server and on the client)?
  • How flexible are Linux networking primitives?
  • How does iptables work on Linux? What does it actually do?
  • How do NATs deal with different kinds of traffic (e.g. ICMP)?
  • How does DNS work? How do DNS records propagate? How do custom nameservers (e.g. with NS records) work? How does something like iodine work?


  • Re-implement something like Little Snitch.
  • Try using the Berkeley Packet Filter (or some other API) to make a live bandwidth monitor.
  • Implement a packet-level WiFi client that gets all the way up to being able to make DNS queries. I started this with gofi and wifistack, but never finished.
  • Implement a program that exposes a fake LAN with a fake web server on some fake IP address. This will involve writing your own network stack.
  • Implement a user-space NAT that exposes a fake “gateway” through a tunnel interface.
  • Re-implement iodine in a way that parallelizes packet transmission to be faster on satellite internet connections (like on an airplane).
  • Re-implement something like ifconfig using system calls.
  • Connect your computer to both Ethernet and WiFi, and try to write a program that parallelizes an HTTP download over both interfaces simultaneously.
  • Write a script to DOS a VPN server by allocating a ton of IP addresses.
  • Write a simple VPN-like protocol and make a server/client for it.
  • Try to set something up where you can create a new Docker container and assign it its own IP address from a VPN.
  • Try to implement a simple firewall program that hooks into the same level of the network stack as iptables.
  • Implement a fake DNS server and set up your router to use it. The DNS server could forward most requests to a real DNS server, but provide fake addresses for specific domains of your choosing. This would be fun for silly pranks, or for logging domains people visit.
  • Try to bypass WiFi paywalls at hotels, airports, etc.


Cryptocurrency

Cryptocurrencies are extremely popular right now. So, as a tech nerd, I feel kind of lame knowing nothing about them. Maybe I should fix that!


  • How do cryptocurrencies actually work?
  • What kinds of network protocols do cryptocurrencies use?
  • What does it mean that Ethereum is a distributed virtual machine?
  • What computations are actually involved in mining cryptocurrencies?


  • Implement a toy cryptocurrency.
  • Write a script that, without using high-level APIs, transfers some cryptocurrency (e.g. Bitcoin) from one wallet to another.
  • Write a small program (e.g. “Hello World”) that runs on the Ethereum VM. I honestly don’t even know if this is possible.
  • Try writing a Bitcoin mining program from scratch.

Source Control

Before OpenAI, I worked mostly on solitary projects. As a result, I only used a small subset of the features offered by source control tools like Git. I never had to deal with complex merge conflicts, rebases, etc.


  • What are some complicated use-cases for git rebase?
  • How do code reviews typically work on large open source projects?
  • What protocol does git use for remotes? Is a Github repository just a .git directory on a server, or is there more to it than that?
  • What are some common/useful git commands besides git push, git pull, git add, git commit, git merge, git remote, git checkout, git branch, and git rebase? Also, what are some unusual/useful flags for the aforementioned commands?
  • How do you actually set up an editor to work with git add -p?
  • How does git store data internally?


  • Try to write a script that uses sockets/HTTPS/SSH to push code to a Github repo. Don’t use any local git commands or APIs.
  • On your own repos, intentionally get yourself into source control messes that you have to figure your way out of.
  • Submit more pull requests to big open source projects.
  • Read a ton of git man pages.
  • Write a program from scratch that converts tarballs to .git directories. The .git directory would represent a repository with a single commit that adds all the files from the tarball.

Machine Learning

Even though I work in ML, I sometimes forget to keep tabs on the field as a whole. Here’s some stuff that I feel I should brush up on:


  • How do SOTA object detection systems work?
  • How do OCR systems deal with variable-length strings in images?
  • How do neural style transfer and CycleGAN actually work?


  • Make a program that puts boxes around people’s noses in movies.
  • Make a captcha cracker.
  • Make a screenshot-to-text system (easy to generate training data!).
  • Try to make something that takes MNIST digits and makes them look like SVHN digits.


Phones

For something so ubiquitous, phones are still a mystery to me. I’m not sure how easy it is to learn about phones as a user hacking away in his apartment, but I can always try!


  • What does the SMS protocol actually look like? How is the data transmitted? Why do different carriers seem to have different character limits?
  • How does modern telephony work? How are calls routed?
  • Why is it so easy to spoof a phone number, and how does one do it?
  • How are Android apps packaged, and how easy is it to reverse engineer them?


  • Try reverse engineering SMS apps on your phone. Figure out what APIs deal with SMS messages, and how messages get from the UI all the way to the cellular antenna. Try to encode and send an SMS message from as low a level as possible.
  • Get a job at Verizon/AT&T/Sprint/T-Mobile. This is probably not worth it, but telephony is one of those topics that seem pretty hard to learn about from the outside.

Misc. Tools

I don’t take full advantage of many of the applications I use. I could probably get a decent productivity boost by simply learning more about these tools.


  • How do you use ViM macros?
  • What useful keyboard shortcuts does Atom provide?
  • How do custom go get URLs work?
  • How are man pages formatted, and how can I write a good man page?


  • Use a ViM macro to un-indent a chunk of lines by one space.
  • For a day, don’t let yourself use the mouse at all in your text editor. Lookup keyboard shortcuts all you want.
  • For a day, use your editor for all development (including running your code). Feasible with e.g. Hydrogen.
  • Make a go get service that allows for semantic versioning wildcards (e.g. go get
  • Write man pages for some existing open source projects. I bet everybody will thank you.

Adversarial Train/Test Splits

The MNIST dataset is split into 60K training images and 10K test images. Everyone pretty much assumes that the training images are enough to learn how to classify the test images, but this is not necessarily the case. So today, in the theme of being a contrarian, I’ll show you how to split the 70K MNIST images in an adversarial manner, such that models trained on the adversarial training set perform poorly on the adversarial test set.

Class imbalance is the most obvious way to make an adversarial split. If you put all the 2’s in the test set, then a model trained on the training set will miss all those 2’s. However, since this answer is fairly boring and doesn’t say very much, let’s assume for the rest of this post that we want class balance. In other words, we want the test set to contain 1K examples of each digit. This also means that, just by looking at the labels, you won’t be able to tell that the split is adversarial.

Here’s another idea: we can train a model on all 70K digits, then get the test set by picking the digits that have the worst loss. What we end up with is a test set that contains the “worst” 1K examples of each class. It turns out that this approach gives good results, even with fairly simple models.

When I tried the above idea with a small, single-layer MLP, the results were surprisingly good. The model, while extremely simple, gets 97% test accuracy on the standard MNIST split. On the adversarial split, however, it only gets 81% test accuracy. Additionally, if you actually look at the samples in the training set versus those in the test set, the difference is quite noticeable. The test set consists of a bunch of distorted, smudgy, or partially-drawn digits.
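
The selection step itself is straightforward. Here is a hypothetical sketch in Go (the function name and the toy loss values are made up; in practice the losses would come from the trained model):

```go
package main

import (
	"fmt"
	"sort"
)

// worstPerClass returns the indices of the perClass highest-loss
// examples of each label, preserving class balance in the test set.
func worstPerClass(losses []float64, labels []int, numClasses, perClass int) []int {
	// Group example indices by class label.
	byClass := make([][]int, numClasses)
	for i, label := range labels {
		byClass[label] = append(byClass[label], i)
	}
	var testSet []int
	for _, indices := range byClass {
		// Sort each class's examples by descending loss.
		sort.Slice(indices, func(a, b int) bool {
			return losses[indices[a]] > losses[indices[b]]
		})
		n := perClass
		if n > len(indices) {
			n = len(indices)
		}
		testSet = append(testSet, indices[:n]...)
	}
	return testSet
}

func main() {
	// Toy data: two classes with made-up losses.
	losses := []float64{0.1, 0.9, 0.3, 0.8, 0.2, 0.7}
	labels := []int{0, 0, 0, 1, 1, 1}
	// Picks the single worst example of each class.
	fmt.Println(worstPerClass(losses, labels, 2, 1)) // [1 3]
}
```

For MNIST you would call this with numClasses=10 and perClass=1000, and put every index not returned into the training set.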

Example images from the adversarial train/test split.
Training samples (left), testing samples (right).

In a sense, the result I just described was obtained by cheating. I used a certain architecture to generate an adversarial split, and then I used the same architecture to test the effectiveness of that split. Perhaps I just made a split that’s hard for an MLP, but not hard in general. To truly demonstrate that the split is adversarial, I needed to test it on a different architecture. To this end, I tried a simple CNN that gets 99% accuracy on the standard test set. On the MLP-generated adversarial split, this model gets a whopping 87% test accuracy. So, it’s pretty clear that the adversarial split works across architectures.

There are plenty of directions to take this from here. The following approaches could probably be used to generate even more adversarial splits:

  • Use an ensemble of architectures to select the adversarial test set.
  • Use some kind of iterative process, continually making the split worse and worse.
  • Try some kind of clustering algorithm to make a test set that is a fundamentally different distribution than the training set.

Maybe this result doesn’t mean anything, but I still think it’s cool. It shows that the way you choose a train/test split can matter. It also shows that, perhaps, MNIST is not as uniform as people think.

The code for all of this can be found here.

Decision Trees as RL Policies

In supervised learning, there are very good “shallow” models like XGBoost and SVMs. These models can learn powerful classifiers without an artificial neuron in sight. So why, then, is modern reinforcement learning totally dominated by neural networks? My answer: no good reason. And now I want to show everyone that shallow architectures can do RL too.

Right now, using absolutely no feature engineering, I can train an ensemble of decision trees to play various video games from the raw pixels. The performance isn’t comparable to deep RL algorithms yet, and it may never be for vision-based tasks (for good reason!), but it’s fairly impressive nonetheless.

A tree ensemble playing Atari Pong

So how exactly do I train shallow models on RL tasks? You might have a few ideas, and so did I. Today, I’ll just be telling you about the one that actually worked.

The algorithm itself is so simple that I’m almost kind of embarrassed by my previous (failed) attempts at tree-based RL. Essentially, I use gradient boosting with gradients from a policy gradient estimator. I call the resulting algorithm policy gradient boosting. In practice, I use a slightly more complex algorithm (a tree-based variant of PPO), but there is probably plenty of room for simplification.

With policy gradient boosting, we build up an ensemble of trees in an additive fashion. For every batch of experience, we add a few more trees to our model, making minor adjustments in the direction of the policy gradient. After hundreds or even thousands of trees, we can end up with a pretty good policy.
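
In symbols (my own sketch of the idea, not necessarily how treeagent implements it), the policy’s action preferences form an additive ensemble of trees:

```latex
F_m(s) = F_{m-1}(s) + \eta \, f_m(s), \qquad
\pi_m(a \mid s) = \operatorname{softmax}\big(F_m(s)\big)_a
```

where \(\eta\) is a step size and each new tree \(f_m\) is fit by regression to the per-state gradients \(\nabla_{F(s_t)} \hat{J}\) of a policy gradient (or PPO-style) surrogate objective \(\hat{J}\), so that adding the tree nudges the whole ensemble in the direction of the policy gradient.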

Now that I’ve found an algorithm that works pretty well, I want to figure out better hyper-parameters for it. I doubt that tree-based PPO is the best (or even a good) technique, and I doubt that my regularization heuristic is very good either. Yet, even with these somewhat random choices, my models perform well on very difficult tasks! This has convinced me that shallow architectures could really disrupt the modern RL landscape, given the proper care and attention.

All the code for this project is in my treeagent repository, and there are some video demonstrations up on this YouTube playlist. If you’re interested, feel free to contribute to treeagent on Github or send me a PM on Twitter.

Keeping Tabs On All My Neural Networks

When I’m out in public, I look at my watch a lot. It’s not because I’m nervous, or because I’m obsessed with the time. It’s because I’m checking on my neural networks.

At any given moment, I probably have four different machines crunching through ML tasks (e.g. training neural networks or downloading data). To keep tabs on all these machines, I use my own logging system called StatusHub. With StatusHub, I can use my phone, my watch, my tablet, or my laptop to see logs from every job across all of my machines. On my watch, I see a scrollable list that looks like this:

On my phone or laptop, I can see the same log through a web UI:

I can even view the log through the command-line, but I won’t bore you with a picture of that one.

Pushing log messages

You push log messages to a StatusHub server with the sh-log command. Without sh-log, I might train a neural network like so:

$ go run *.go
2017/07/03 18:32:57 done 4002029568 updates: cost=0.108497
2017/07/03 18:32:58 done 4002168832 updates: cost=0.114127
2017/07/03 18:32:59 done 4002308096 updates: cost=0.109726

As you can see, the program already produces log messages, but annoyingly they only go to standard error. To push the messages to StatusHub, we can simply use sh-log:

$ sh-log TrainEmbedding go run *.go
2017/07/03 18:32:57 done 4002029568 updates: cost=0.108497
2017/07/03 18:32:58 done 4002168832 updates: cost=0.114127
2017/07/03 18:32:59 done 4002308096 updates: cost=0.109726

In the above example, sh-log executes go run *.go and echoes the standard output/error to a StatusHub server (which is configured via environment variables). The first argument to sh-log is the service name, which helps to distinguish between different jobs in the log. If you look back at the screenshots from the beginning of this post, the service names should stand out right away.

The sh-log command also plays nicely with UNIX pipes. If you don’t provide a command for sh-log to run, it reads directly from standard input. For example, this is how I log information about my GPU:

$ nvidia-smi | head -n 9 | tail -n 1 | cut -b 3-77 | sed -e 's/\s\s*/ /g' | sh-log GPU
23% 35C P8 10W / 250W | 63MiB / 12186MiB | 0% Default

Viewing logs

The simplest way to view a log is via the StatusHub web UI or through the Android Wear application. However, StatusHub also ships with some commands for reading and manipulating logs.

To dump the log for a given service, there is the sh-dump command:

$ sh-dump tweetdump
-rw-r--r--  1 alex  staff  12141949085 Jul  3 19:09 tweets.csv
-rw-r--r--  1 alex  staff  12142001648 Jul  3 19:09 tweets.csv
-rw-r--r--  1 alex  staff  12142061169 Jul  3 19:10 tweets.csv
-rw-r--r--  1 alex  staff  12142116283 Jul  3 19:10 tweets.csv

You can also use the sh-stream command to view the output of a service in real time, for example:

$ sh-stream tweetdump
-rw-r--r--  1 alex  staff  12141949085 Jul  3 19:09 tweets.csv
-rw-r--r--  1 alex  staff  12142001648 Jul  3 19:09 tweets.csv
-rw-r--r--  1 alex  staff  12142061169 Jul  3 19:10 tweets.csv
-rw-r--r--  1 alex  staff  12142116283 Jul  3 19:10 tweets.csv

My favorite tool, though, is sh-avg. Using sh-avg, you can compute the averages of numerical fields over the last several log messages. For example, to average the results from the “TrainEmbedding” service:

$ sh-avg TrainEmbedding
size 10: cost=0.108642
size 20: cost=0.108811
size 50: cost=0.108578

You can also specify a particular average size (i.e. the number of log records to average):

$ sh-avg TrainEmbedding 32
size 32: cost=0.108722

If you want to be able to quickly see averages from your phone or smartwatch, you can set up a job to log the averages of another job:

$ while (true) do sh-log EmbedAvg sh-avg TrainEmbedding 30; sleep 30; done

As you can see, StatusHub allows you to be a command-line ninja with magical logging powers.

Going crazy

Once you have basic building blocks like sh-log and sh-stream, the possibilities are boundless.

With a pipe-based IRC client like ii, you can push chat logs to StatusHub in one terminal command. This makes it easy to keep tabs on IRC activity, even from devices without an IRC client (e.g. a smartwatch).

You could also pipe sh-stream into ii in order to send log messages to someone on IRC. This might not seem useful, but it actually could be. For example, say you want to be notified when a process finishes running. You could run this in one terminal:

$ ./some_long_task; sh-log Notify echo 'Task done!'

And then in some other terminal, perhaps on some other machine, run something like this:

$ sh-stream Notify | send-irc-messages

Using StatusHub yourself

The StatusHub repository has official installation instructions, but I thought I’d give a gist here as well. There are really three parts to a successful StatusHub installation:

  1. An sh-server process running on some internet-accessible machine. All of your devices should be able to connect to this machine over HTTP/HTTPS, either through a reverse proxy or via port forwarding.
  2. A set of jobs that do logging via the sh-log command. To have sh-log go to the correct server, you will need to set some environment variables.
  3. One or more devices from which you will consume logs. These devices simply need a browser, but you can also install the StatusHub commands and set up your environment variables accordingly.