Large-Scale Vehicle Classification

As a hobby project over the past few weeks, I’ve been training neural networks to predict the prices of cars in photographs. Anybody with some experience in ML can probably guess what I did, at least at a high level: I scraped terabytes of data from the internet and trained a neural network on it. However, I ended up learning more than I expected to along the way. For example, this was the first time I’ve observed clear and somewhat surprising positive transfer across prediction tasks. In particular, I found that predicting additional details about each car (e.g. make and model) improved the accuracy of price predictions. Additionally, I learned a few things about file systems, specifically EXT4, that caught me off guard. And finally, I used AI-powered image editing to uncover some interesting behaviors of my trained classifier.

This project started off as a tweet a few months ago. I can’t remember exactly why I wanted this, but it was probably one of the many times that some exotic car caught my eye in the Bay Area. After tweeting out this idea, I didn’t actually consider implementing it for at least a month.

You might be thinking, “But wait! The app can’t possibly tell you exactly how much a car is worth from just one photo—there are just too many unknowns. Does the car actually work? Does it smell like mold on the inside? Is there some hidden dent not visible in the photo?” These are all valid points, but even though there are many unknowns, it’s often apparent to a human observer when a car looks expensive or cheap. Our model should be at least as good: it should be able to tell you a ballpark estimate of the price, and maybe even estimate the uncertainty of its own predictions.

The run-of-the-mill approach for building this kind of app is to 1) download a lot of car images with known prices from the internet, and 2) train a neural network to input an image and predict a price. Some readers may expect step 1 to be boring, while others might expect the exact same thing of step 2. I expected both steps to be boring and straightforward, but I learned some surprising lessons along the way. However, if that still doesn’t sound interesting to you, you can always skip to the results section to watch me have fun with a trained model.

Curating a Dataset

So let’s download millions of images of vehicles with known prices! This shouldn’t be too hard, since there are tons of used car websites out there. I chose to use Kelley Blue Book (KBB)—not for any particular reason, but just because I’d heard of it before. It turned out to be pretty straightforward to scrape KBB because every listing had a unique integer ID, and all the IDs were predictably distributed in a somewhat narrow range. I ran my scraper for about two weeks, gathering about 500K listings and a total of 10 million images (note that each listing can contain multiple images of a single vehicle). During this time, I noticed that I was running low on disk space, so I added image downsampling/compression to my download script to avoid saving needlessly high-resolution images.
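As a rough sketch of that downsampling step (this is not the actual scraper; the maximum side length and JPEG quality are placeholders I picked for illustration):

import io

from PIL import Image

MAX_SIDE = 512      # placeholder cap on the longest image side
JPEG_QUALITY = 85   # placeholder re-compression quality

def save_downsampled(image_bytes: bytes, out_path: str) -> None:
    """Shrink and re-compress a downloaded image before writing it to disk."""
    img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    img.thumbnail((MAX_SIDE, MAX_SIDE))  # in-place resize that preserves aspect ratio
    img.save(out_path, format="JPEG", quality=JPEG_QUALITY)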

Then came my first surprise: despite having more than enough disk space, I was hitting a mysterious “No space left on device” error when writing images to disk. I quickly noticed an odd pattern: I could create a new file, write into it, or copy existing files, but I could not create files with particular names. After some Googling, I found out that this was a limitation of EXT4 when creating directories with millions of files in them. In particular, the file system maintains a fixed-size hash table for each directory, and when a particular hash table bucket fills up, the driver returns a “No space left on device” error for any filename that would go into that bucket. The fix was to disable this hash table, which was simple and only took a few minutes.

And voila, no more I/O errors! However, now opening files by name took a long time—sometimes on the order of 0.1 seconds—presumably because the file system had to re-scan the directory to look up each file. When training models later on, this ended up slowing down training to the extent that the GPU was barely utilized. To mitigate this, I built my own hash table on top of the file system using a pretty common approach. In particular, every image’s filename was already a hexadecimal hash, so I sorted the files into sub-directories based on the first two characters of their names. This way, I essentially created a hash table with 256 buckets, which seemed like enough to prevent the file system from being the bottleneck during scraping or data loading.
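The sharding scheme amounts to something like this minimal sketch (the exact helper names are mine, not from the repo):

import os
import shutil

def shard_path(data_dir: str, filename: str) -> str:
    """Map a hex-named file into one of 256 sub-directories based on its first two characters."""
    bucket = filename[:2]
    os.makedirs(os.path.join(data_dir, bucket), exist_ok=True)
    return os.path.join(data_dir, bucket, filename)

def shard_directory(data_dir: str) -> None:
    """One-time migration of an existing flat directory into the 256 buckets."""
    for name in os.listdir(data_dir):
        src = os.path.join(data_dir, name)
        if os.path.isfile(src):
            shutil.move(src, shard_path(data_dir, name))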

One thing I worried about was duplicate images in the dataset. For example, the same generic thumbnail image might be used every time a certain dealership lists a Nissan Altima for sale. While I’ve implemented fancy methods for dataset deduplication before (e.g. for DALL-E 2), I decided to go for a simpler approach to save compute and time. For each image, I computed a “perceptual hash” by downsampling the image to 16×16 and quantizing each color channel to a few bits, and then applied SHA256 to the quantized bitmap data. I then deleted every image whose hash appeared more than once in the dataset. This ended up clearing out about 10% of the scraped images.
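Here is a minimal sketch of that perceptual hash (the number of bits kept per channel is an assumption on my part):

import hashlib

import numpy as np
from PIL import Image

def perceptual_hash(path: str, bits_per_channel: int = 2) -> str:
    """SHA256 of a heavily downsampled, coarsely quantized version of the image."""
    img = Image.open(path).convert("RGB").resize((16, 16))
    arr = np.asarray(img)
    quantized = arr >> (8 - bits_per_channel)  # keep only the top few bits of each channel
    return hashlib.sha256(quantized.tobytes()).hexdigest()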

Once I had downloaded and deduplicated the dataset, I went through some of the images and saw that there was a lot of junk. By “junk”, I mean images that did not seem particularly useful for the task at hand. We want to classify photos of whole vehicles, not close-ups of wheels, dashboards, keys, etc. To remove these sorts of images from the dataset, I hand-labeled a few hundred good and bad images, and trained an SVM on top of CLIP features on this tiny dataset. I tuned the threshold of the SVM to have a false-negative rate under 1% to make sure almost all good (positive) data was kept in the dataset (at the expense of leaving in some bad data). This filtering ended up deleting about another 50% of the images.
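A sketch of that filtering setup might look like the following. This is not my exact code: the CLIP variant, the use of a linear SVM, and the labeled_paths/labels placeholders for the hand-labeled set are all assumptions for illustration.

import clip
import numpy as np
import torch
from PIL import Image
from sklearn.svm import LinearSVC

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)  # CLIP variant chosen arbitrarily here

def clip_features(paths):
    with torch.no_grad():
        batch = torch.stack([preprocess(Image.open(p)) for p in paths]).to(device)
        feats = model.encode_image(batch).float()
        return (feats / feats.norm(dim=-1, keepdim=True)).cpu().numpy()

# labeled_paths/labels: the small hand-labeled set (1 = whole vehicle, 0 = junk).
X_train, y_train = clip_features(labeled_paths), np.array(labels)
svm = LinearSVC().fit(X_train, y_train)

# Choose the threshold so that at most ~1% of known-good images fall below it.
good_scores = svm.decision_function(X_train[y_train == 1])
threshold = np.quantile(good_scores, 0.01)

def keep_image(path: str) -> bool:
    return svm.decision_function(clip_features([path]))[0] >= threshold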

And with that, dataset curation was almost done. Notably, I scraped much more metadata than just prices and images. I also dumped the text description, year, make/model, odometer reading, colors, engine type, etc. Also, I created a few plots to understand how the data was distributed:

A histogram of vehicle prices from the dataset. I was surprised to find that around 20% of listings were over $50k and around 6% were over $70k. I expected these high prices to be more rare.
Make / model           % of the dataset
Ford F150              3.75%
Chevrolet Silverado    3.41%
RAM 1500               2.11%
Jeep Wrangler          1.88%
Ford Explorer          1.69%
Nissan Rogue           1.64%
The most prevalent make/models in the dataset. Just the top 6 make up almost 15% of all the vehicles.

Training a Model

Now that we have the dataset, it’s time to train a model! There are two more ingredients we need before we can start burning our GPUs: a training objective and a model architecture. For the training objective, I decided to frame the problem as a multi-class classification problem, and optimized the cross-entropy loss. In particular, instead of predicting an exact numerical price, the model predicts the probability that the price falls in a pre-defined set of ranges (e.g. the model can tell you “there is a 30% chance the price is between $10,000 and $15,000”). This setup forces the model to predict a probability distribution rather than just a single number. Among other things, this can help show how confident the model is in its prediction.
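Concretely, the objective is just cross-entropy over discretized prices, as in this sketch (the bucket edges here are hypothetical; the post doesn’t specify the real ones):

import torch
import torch.nn.functional as F

# Hypothetical bucket edges in dollars.
PRICE_EDGES = torch.tensor([5_000., 10_000., 15_000., 20_000., 30_000., 50_000., 70_000., 100_000.])

def price_to_bucket(prices: torch.Tensor) -> torch.Tensor:
    """Map each price to the index of the range it falls into (len(PRICE_EDGES) + 1 buckets)."""
    return torch.bucketize(prices.float(), PRICE_EDGES)

def price_range_loss(logits: torch.Tensor, prices: torch.Tensor) -> torch.Tensor:
    # logits: [batch, num_buckets]; softmax(logits) is the predicted distribution over ranges.
    return F.cross_entropy(logits, price_to_bucket(prices))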

I tried training two different model architectures, both fine-tuned from pre-trained checkpoints. To start off strong with a near state-of-the-art model, I tried fine-tuning CLIP ViT-B/16. For a more nimble option, I also fine-tuned a MobileNetV2 that was pre-trained on ImageNet1K. Unlike the CLIP model, the MobileNetV2 is tiny (only a few megabytes) and runs very fast—even on a laptop CPU. I liked the idea of this model not only because it trained faster, but also because it would be easier to incorporate into an app or serve cheaply on a website. I did all of my training runs on my home PC, which has a single Titan X GPU with 12GB of VRAM.
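For the MobileNetV2 variant, the fine-tuning setup is roughly just swapping the classifier head on the torchvision checkpoint; a sketch, not my exact training code:

import torch
import torchvision

def make_backbone(num_outputs: int) -> torch.nn.Module:
    """MobileNetV2 pre-trained on ImageNet1K, with a fresh output layer for fine-tuning."""
    net = torchvision.models.mobilenet_v2(weights="IMAGENET1K_V1")
    net.classifier[1] = torch.nn.Linear(net.last_channel, num_outputs)
    return net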

In addition to the price range classification task, I also tried adding some auxiliary prediction tasks to the model. First, I added a separate linear output layer to estimate the median price as a single numerical value (to force the model to estimate the median and not the mean, I used the L1 loss). I also added an output layer for the make/model of the vehicle. Instead of predicting make and model independently, I treated the make/model pair as a class label. I kept 512 classes for this task, since this covered 98.5% of all vehicles, and added an additional “Unknown” class for the remaining listings. I also added an output layer for the manufacture year (as another multi-class task), since age can play a large role in the price of a car.
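Putting the heads together, the output side of the model looks roughly like this sketch sitting on top of the backbone’s features (the year-head size is an assumption; the make/model head is the 512 classes plus “Unknown”):

import torch
import torch.nn as nn
import torch.nn.functional as F

class PredictionHeads(nn.Module):
    """Price-range head plus the auxiliary heads described above."""
    def __init__(self, feature_dim: int, num_price_buckets: int,
                 num_make_models: int = 513, num_years: int = 40):
        super().__init__()
        self.price_range = nn.Linear(feature_dim, num_price_buckets)
        self.median_price = nn.Linear(feature_dim, 1)               # L1 loss => median estimate
        self.make_model = nn.Linear(feature_dim, num_make_models)   # 512 classes + "Unknown"
        self.year = nn.Linear(feature_dim, num_years)

    def losses(self, features, targets):
        return {
            "price_range": F.cross_entropy(self.price_range(features), targets["price_bucket"]),
            "median_price": F.l1_loss(self.median_price(features).squeeze(1), targets["price"]),
            "make_model": F.cross_entropy(self.make_model(features), targets["make_model"]),
            "year": F.cross_entropy(self.year(features), targets["year"]),
        }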

I expected the auxiliary prediction tasks to hurt performance on the main price range prediction task. After all, the extra tasks give the model more work to do, so it should struggle more with each individual task. To my surprise, this was not the case. When I added all of these auxiliary prediction tasks, the accuracy and cross-entropy loss for price range prediction actually improved faster and seemed to be converging to better values. This still leaves the question: which auxiliary losses contribute to the positive transfer? One data point I have is a buggy run, where I accidentally scaled the median price prediction layer incorrectly such that it was effectively unused. Even for this run, the positive transfer can be observed from the loss curves, indicating that the positive transfer mostly comes from predicting the make/model and year.

I’m not quite sure how to explain this surprising positive transfer. Perhaps prices are a very noisy signal, and adding more predictable variables helps learn more relevant features. Or perhaps nothing deep is going on at all, and adding more tasks is somehow equivalent to increasing the batch size or learning rate (these are two important hyperparameters that I did not have compute to tune). Regardless of the reason, having a bunch of auxiliary predictions is useful in and of itself, and can make the output of the model easier to interpret.

Looking at the above loss curves, you may be concerned that the accuracy is quite low (around 50% for the best checkpoint). However, it’s difficult to know if this is actually good or bad. Perhaps there is simply not enough information in single photos to infer the exact price of a car. One observation in support of this hypothesis is that the cross-entropy loss for make/model prediction was actually lower (around 0.5 nats) than the price range cross-entropy loss (around 1.2 nats). This means that, even though there are almost two orders of magnitude more make/model classes than price ranges, predicting the exact make/model is much easier than predicting the price. This makes sense: an image will usually be enough to tell what kind of car you are looking at, but won’t contain all of the hidden information (e.g. mileage) that you’d need to determine how expensive the car is.

Another thing you might have noticed from the loss curves is that none of these models have converged. This is not for any particularly good reason, except that I wanted to write this blog post before the end of the winter holidays. I will likely continue to train my best models until convergence, and may or may not update this post once I do.

Results

In this section, I will explore the capabilities and limitations of my smaller MobileNetV2-based model. While this model has worse accuracy than the fine-tuned CLIP model, it is much cheaper to run, and is likely what I would deploy if I turned this project into a real app. Overall, I was surprised how accurate and robust this small model was, and I had a lot of fun exploring it.

Starting off strong, I tested the model on photos of cars that I found in my camera roll. For at least three of these cars, I believe the make/model predictions are correct, and for one car I’m not sure what the correct answer should be. It’s interesting to note how well the model seems to work even for cars with unusual colors and patterns, which tend to dominate my camera roll.

Model predictions for cars that I found in my camera roll.

Of course, my personal photos aren’t very representative of what cars are out there. Let’s mix things up a bit by creating synthetic images of cars using DALL-E 2. I found it helpful to append “parked on the street” to the prompts to get a wider shot of each car. To me, all of the price predictions seem to make sense. Impressively, the model correctly predicts a “Tesla Model S” for the DALL-E 2 generation of a Tesla. The model also seems to predict that the “cheap car” is old.

Model predictions for images created by DALL-E 2. The prompts, in order, were: 1) "a sports car parked on the street", 2) "a cheap car parked on the street", 3) "a tesla car parked on the street", 4) "a fancy car parked on the street".
Model predictions for images created by DALL-E 2. The prompts, in order, were: 1) “a sports car parked on the street”, 2) “a cheap car parked on the street”, 3) “a tesla car parked on the street”, 4) “a fancy car parked on the street”.

So here’s a question: is the model just looking at the car itself, or is it looking at the surrounding context for more clues? For example, a car might be more likely to be expensive if it’s in a suburban neighborhood than if it seems to be in a shady abandoned lot. We can use DALL-E 2 “Edits” to evaluate exactly this. Here, I’ve taken a real photo of a car from my camera roll, used DALL-E 2 to remove the license plate, and then changed the background in various ways using another editing step:

Model predictions when using DALL-E 2 to edit the scene behind a car without modifying the car itself. The model believes the car is more expensive when it is parked in a suburban neighborhood than when it is parked in an empty lot or next to row homes.

And voila! It appears that, even though the model predicts the same make/model for all of the images, the background can influence the predicted price by almost $10k! After seeing this result, I suddenly found a new appreciation for what can be studied using AI image-editing tools. With these tools, it is easy to conduct intervention studies where some part of an image is changed in interpretable ways. This seems like a really neat way to probe small image classification models, and I wonder if anybody else is doing it.

Here’s another question: is the model relying solely on the car’s logo to predict the make/model, or is it doing something more general? To study this, I took another photo of a car, edited the license plate, and then repeatedly re-generated different logos using DALL-E 2. The model appears to predict that the car is an Audi in every case, even though only the first image contains a recognizable Audi logo.

Model predictions when editing the logo on the back of a car. The model predicts the same make/model despite the logo being modified.

For fun, let’s try one more experiment with DALL-E 2, where we generate out-of-distribution images of “cars”:

Model predictions from unusual-looking AI-generated cars.

Happily, the model does not confidently claim to understand what kind of cars these are. The price and year estimates are interesting, but I’m not sure how much to read into them.

In some of my earlier examples, the model correctly predicts the make/model of cars that are only partially visible in the photo. To study how far this can be pushed, I panned a crop along a side-view of two different cars to see how the model’s predictions changed as different parts of the car became visible. In these two examples, the model was most accurate when viewing the front or back of the car, but not when only the middle of the car was visible. Perhaps this is a shortcoming of the model, or perhaps automakers customize the front and back shapes of their car more than the sides. I’d be happy to hear other hypotheses as well!
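The panning itself is simple; roughly the following (the crop count and geometry are my own choices for illustration):

from PIL import Image

def panned_crops(path: str, num_crops: int = 8):
    """Slide a square crop from left to right across a wide photo of a car."""
    img = Image.open(path)
    side = img.height
    crops = []
    for i in range(num_crops):
        left = round(i * (img.width - side) / (num_crops - 1))
        crops.append(img.crop((left, 0, left + side, side)))
    return crops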

Model predictions when panning a square crop over a long view of some cars.

Conclusion

In this post, I took you along with me as I scraped millions of online vehicle listings and trained models on the resulting data. In the process, I observed an unusual phenomenon where auxiliary losses actually improved the loss and accuracy of a classifier. Finally, I used AI-generated images to study the behavior of the resulting classifier, finding some interesting results.

After the fact, I realized that this project might be something others have already tried. After a brief online search, the closest project I found was this, which scraped only ~30k car listings (much smaller than my dataset of 500k). I also couldn’t find evidence of an image classifier trained on this data. I also found this paper which used a fairly small subset of the above dataset to predict the make/model out of a handful of classes; this still doesn’t seem particularly general or useful. After doing this research, I think my models might truly be the best or most general ones out there for what they do, but that wasn’t the main aim of the project.

The code for this project can be found on GitHub in the car-data repository. I also created a Gradio demo of the MobileNetV2 model, where you can upload your own images and see results.

Data and Machines

Suppose the year is 2000 BC and you want to set eyes upon a specific, non-naturally occurring color. To create this color, you must look for the right combination of plants, sand, soils, and clay to produce paint–and even then, there is no guarantee you will succeed exactly. Today, you can search “color picker” on the internet and find a six-digit hex code representing the color in question. Not only can your phone display the color corresponding to the hex code–you can also send this hex code to anybody on Earth and they will see nearly the exact same color. You can just as easily manipulate the hex code, making it darker or lighter, inverting it, or making it “more green”. What’s more, you can use a camera to extract hex codes from natural images and manipulate these in exactly the same way as any other hex code.

What I’ve described is what happens when powerful machines meet powerful data structures. A data structure, in this case a hex code, is useless on its own. You could easily write down “#65bcd4” 200 years ago, but it would not be usable like it is today. The opposite is also true: it would be very difficult to create computer screens or cameras without defining an abstract data structure to represent color. The benefit of data structures is that they give us–and machines–an easy way to manipulate and duplicate state from the physical world.

For most of recorded history, humans were the main machines that operated on data structures. Language is perhaps the biggest example of structured data that is arguably useless on its own, but incredibly powerful when paired with the right machines. Another example is music notation (with the accompanying abstractions such as notes and units of time), where humans (often paired with instruments) are the decoding machines. However, things are changing, and non-human machines can now read, write, and manipulate many forms of data that were once exclusively human endeavors.

The right combination of machines and data structures can be truly revolutionary. For example, the ability to capture images and videos from the world, manipulate them in abstract form within a computer program, and then display the resulting imagery on a screen has transformed how people live their lives.

A good data structure can even shape how people think about the physical world. For example, defining colors in terms of three basis components should not obviously work, and requires knowledge of how the human eye perceives light. But now that we have done this research, any human who learns about the XYZ color space will have a clear picture of exactly which human-perceivable colors can be created and how they relate to one another.

It’s worth noting that data structures paired with powerful machines are not always revolutionary, at least not right away. What if, for example, we tried to do the same thing with smell as we did for color: create a data structure for representing smells using base components, and then build machines for producing and detecting smells using our new data structure? Well, this has been tried many times, but has yet to break into the average person’s household. The reasons are a bit unclear to me, but it doesn’t appear to be a pure machine or data structure bottleneck. We can’t yet build the right machines, but we also don’t understand the human olfactory system well enough to define a good data structure for smell.

There are also plenty of examples of existing–although not rigorously defined–data structures that could be brought to life with the right machines. Imagine food 3D printers that can follow a human-written recipe, or machines that apply a full face of makeup to match a photograph. I think millions of people would spend money to download a file and immediately have a face of makeup that looks exactly like Lady Gaga on the red carpet. And these same customers would probably also love to be able to tell Alexa to add more salt to last night’s dinner and print it again next weekend. Here, I think the main bottleneck is that we simply don’t have the machines yet; the data structures themselves are fun and possibly even easy to dream up.

I’d argue that, most of the time, we already have the data structures; the problem is that humans are the only machines that can operate on them. For these sorts of problems, machine learning and robotics may one day offer a reasonable solution. Humans are still the machines making recipes, putting on makeup, cutting hair, etc. The data structures we use for these tasks are often encoded in natural language or images, and putting them into physical form requires dexterous manipulation and intelligence. Even decoding these concepts from the world is sometimes out of reach of current machines (e.g. “describe the haircut in this photograph”). The last example also hints that machine learning may make it easier to translate between different, largely equivalent data structures.

The end result of this mechanization will be truly amazing, and perhaps frightening. Imagine a world, not so far from now, where things we today consider “uncopyable” or “analog” become digital and easily manipulable. These could include haircuts, food, physical objects, or even memories and personalities. This seems to be the natural conclusion of technological advancement, and I’m excited and horrified to witness some of it in my lifetime.

VQ-DRAW: A New Generative Model

Today I am extremely excited to release VQ-DRAW, a project which has been brewing for a few months now. I am really happy with my initial set of results, and I’m so pleased to be sharing this research with the world. The goal of this blog post is to describe, at a high level, what VQ-DRAW does and how it works. For a more technical description, you can always check out the paper or the code.

In essence, VQ-DRAW is just a fancy compression algorithm. As it turns out, when a compression algorithm is good enough, decompressing random data produces realistic-looking results. For example, when VQ-DRAW is trained to compress pictures of faces into 75 bytes, this is what happens when it decompresses random data:

Samples from a VQ-DRAW model trained on the CelebA dataset.

Pretty cool, right? Of course, there are plenty of other generative models out there that can dream up realistic-looking images. VQ-DRAW has two properties that set it apart:

  1. The latent codes in VQ-DRAW are discrete, allowing direct application to lossy compression without introducing extra overhead for an encoding scheme. Of course, other methods can produce discrete latent codes as well, but most of these methods either don’t produce high-quality samples, or require additional layers of density modeling (e.g. PixelCNN) on top of the latent codes.
  2. VQ-DRAW encodes and decodes data in a sequential manner. This means that reconstruction quality degrades gracefully as latent codes are truncated. This could be used, for example, to implement progressive loading for images and other kinds of data.

MNIST digits as more latent codes become available. The reconstructions gradually become crisper, which is ideal for progressive loading.

The core algorithm behind VQ-DRAW is actually quite simple. At each stage of encoding, a refinement network looks at the current reconstruction and proposes K variations of it. The variation with the lowest reconstruction error is chosen and passed along to the next stage. Thus, each stage refines the reconstruction, adding log2(K) bits of information to the latent code. After N stages of encoding, the latent code is N*log2(K) bits long.
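In PyTorch-ish pseudocode, the encoding loop is roughly the following (a simplified sketch; the refinement network’s interface here is an assumption):

import torch

def vq_draw_encode(x, refine_net, num_stages: int, K: int):
    """Encode a batch of images into num_stages choices out of K options each."""
    recon = torch.zeros_like(x)              # start from a blank reconstruction
    codes = []
    for _ in range(num_stages):
        options = refine_net(recon)          # [B, K, C, H, W]: K proposed refinements
        errors = ((options - x[:, None]) ** 2).flatten(2).sum(dim=-1)  # [B, K]
        best = errors.argmin(dim=1)          # pick the option closest to the target
        recon = options[torch.arange(x.shape[0]), best]
        codes.append(best)                   # each stage adds log2(K) bits
    return torch.stack(codes, dim=1), recon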

VQ-DRAW encoding an MNIST digit in stages. In this simplified example, K = 8 and N = 3.

To train the refinement network, we can simply minimize the L2 distance between images in the training set and their reconstructions. It might seem surprising that this works. After all, why should the network learn to produce useful variations at each stage? Well, as a side-effect of selecting the best refinement options for each input, the refinement network is influenced by a vector quantization effect (the “VQ” in “VQ-DRAW”). This effect pulls different options towards different samples in the training set, causing a sort of clustering to take place. In the first stage, this literally means that the network clusters the training samples in a way similar to k-means.
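Training is then just backpropagation of that reconstruction error through the selected options; a sketch, reusing the vq_draw_encode function from the sketch above:

import torch

def training_step(x, refine_net, optimizer, num_stages: int, K: int) -> float:
    """One gradient step: encode, then minimize L2 between the input and its reconstruction."""
    _, recon = vq_draw_encode(x, refine_net, num_stages, K)
    loss = ((recon - x) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()   # gradients flow only through the options that were selected
    optimizer.step()
    return loss.item()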

With this simple idea and training procedure, we can train VQ-DRAW on any kind of data we want. However, I only experimented with image datasets for the paper (something I definitely plan to rectify in the future). Here are some image samples generated by VQ-DRAW:

Samples from VQ-DRAW models trained on four different datasets.

CIFAR-10 samples being decoded stage-by-stage. Each stage adds a few bits of information.

At the time of this release, there is still a lot of work left to be done on VQ-DRAW. I want to see how well the method scales to larger image datasets like ImageNet. I also want to explore various ways of improving VQ-DRAW’s modeling power, which could make it generate even better samples. Finally, I’d like to try VQ-DRAW on various kinds of data beyond images, such as text and audio. Wish me luck!

Research Projects That Didn’t Pan Out

Anybody who does research knows that ideas often don’t pan out. However, it is fairly rare to see papers or blogs about negative results. This is unfortunate, since negative results can tell us just as much as positive ones, if not more.

Today, I want to share some of my recent negative results in machine learning research. I’ll also include links to the code for each project, and share some theories as to why each idea didn’t work.

Project 1: reptile-gen

Source code: https://github.com/unixpickle/reptile-gen

Premise: This idea came from thinking about the connection between meta-learning and sequence modeling. Research has shown that sequence modeling techniques, such as self-attention and temporal convolutions (or the combination of the two), can be used as effective meta-learners. I wondered if the reverse was also true: are effective meta-learners also good sequence models?

It turns out that many sequence modeling tasks, such as text and image generation, can be posed as meta-learning problems. This means that MAML and Reptile can theoretically solve these problems on top of nothing but a feedforward network. Instead of using explicit state transitions like an RNN, reptile-gen uses a feedforward network’s parameters as a hidden state, and uses SGD to update this hidden state. More details can be found in the README of the GitHub repository.
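To make this concrete, here is a very rough, Reptile-flavored sketch of the idea (not the actual reptile-gen code): the network’s weights act as the hidden state, and each element of the sequence updates that state with one SGD step. For MNIST, the inputs would be pixel coordinates and the targets the binarized pixel values.

import torch
import torch.nn.functional as F

def inner_loop(net, inputs, targets, inner_lr: float = 0.01) -> float:
    """Fit a sequence by treating the network weights as the hidden state."""
    total = 0.0
    for x, y in zip(inputs, targets):
        pred = net(x.unsqueeze(0)).squeeze()
        loss = F.binary_cross_entropy_with_logits(pred, y)
        total += loss.item()
        grads = torch.autograd.grad(loss, list(net.parameters()))
        with torch.no_grad():
            for p, g in zip(net.parameters(), grads):
                p -= inner_lr * g          # the SGD step *is* the state transition
    return total / len(inputs)

def reptile_outer_step(init_net, work_net, inputs, targets, outer_lr: float = 0.1) -> None:
    """Reptile: run the inner loop on a copy, then move the initialization toward the result."""
    work_net.load_state_dict(init_net.state_dict())
    inner_loop(work_net, inputs, targets)
    with torch.no_grad():
        for p0, p in zip(init_net.parameters(), work_net.parameters()):
            p0 += outer_lr * (p - p0)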

Experiments: I applied two meta-learning algorithms, MAML and Reptile, to sequence modeling tasks. I applied both of these algorithms on top of several different feedforward models (the best one was a feedforward network that resembles one step of an LSTM). I tried two tasks: generating sequences of characters, and generating MNIST digits pixel-by-pixel.

Results: I never ended up getting high-quality samples from any of these experiments. For MNIST, I got things that looked digit-esque, but the samples were always very distorted, and my LSTM baseline converged to much better solutions with much less tuning. I did notice a big gap in performance between MAML and Reptile, with MAML consistently winning out. I also noticed that architecture mattered a lot, with the LSTM-like model performing better than a vanilla MLP. Gated activations also seemed to boost the performance of the MLPs, although not by much.

MNIST samples from reptile-gen, trained with MAML.

Takeaways: My main takeaway from this project was that MAML is truly better than Reptile. Reptile doesn’t back-propagate through the inner loop, and as a result it seems to have much more trouble modeling long sequences. This is in contrast to our findings in the original Reptile paper, where Reptile performed about as well as MAML. How could this be the case? Well, in that paper, we were testing Reptile and MAML with small inner loops consisting of fewer than 100 samples; in this experiment, the MNIST inner loop had 784 samples, and the inputs were (x,y) indices (which inherently share very little information, unlike similar images).

While working on this project, I went through a few different implementations of MAML. My first implementation was very easy to use without modifying the PyTorch model at all; I didn’t expect such a plug-and-play implementation to be possible. This made my mind much more open to MAML as an algorithm in general, and I’d be very willing to use it in future projects.

Another takeaway is that sequence modeling is hard. We should feel grateful for Transformers, LSTMs, and the like. There are plenty of architectures which ought to be able to model sequences, but fail to capture long-term dependencies in practice.

Project 2: seqtree

Source code: https://github.com/unixpickle/seqtree

Premise: As I’ve demonstrated before, I am fascinated by decision tree learning algorithms. Ensembles of decision trees are powerful function approximators, and it’s theoretically simple to apply them to a diverse range of tasks. But can they be used as effective sequence models? Naturally, I wondered if decision trees could be used to model the sequences that reptile-gen failed to model. I also wanted to experiment with something I called “feature cascading”, where leaves of some trees in an ensemble could generate features for future trees in the ensemble.

Experiments: Like for reptile-gen, I tried two tasks: MNIST digit generation, and text generation. I tried two different approaches for MNIST digit generation: a position-invariant model, and a position-aware model. In the position-invariant model, a single ensemble of trees looks at a window of pixels above and to the left of the current pixel, and tries to predict the current pixel; in the position-aware model, there is a separate ensemble for each location in the image, each of which can look at all of the previous pixels. For text generation, I only used a position-invariant model.
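The original code is in Go, but the position-invariant setup amounts to something like this sklearn sketch (the window size, binarization, and boosted-tree model are my choices, not the repo’s):

import numpy as np
from sklearn.ensemble import GradientBoostingClassifier

def window_features(img: np.ndarray, row: int, col: int, size: int = 5) -> list:
    """Pixels in a window above and to the left of (row, col); zeros outside the image."""
    feats = []
    for dr in range(-size, 1):
        for dc in range(-size, size + 1):
            if dr == 0 and dc >= 0:
                break  # only pixels generated before the current one
            r, c = row + dr, col + dc
            feats.append(img[r, c] if 0 <= r < img.shape[0] and 0 <= c < img.shape[1] else 0)
    return feats

def build_dataset(images):
    """(window, pixel) pairs pooled across all positions, so one ensemble is shared."""
    xs, ys = [], []
    for img in images:
        for row in range(img.shape[0]):
            for col in range(img.shape[1]):
                xs.append(window_features(img, row, col))
                ys.append(img[row, col])
    return np.array(xs), np.array(ys)

# model = GradientBoostingClassifier(n_estimators=50)
# model.fit(*build_dataset(binarized_mnist_images))  # binarized_mnist_images: placeholder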

Results: The position-invariant models underfit drastically. For MNIST, they generated chunky, skewed digits. The position-aware model was on the other side of the spectrum, overfitting drastically to the training set after only a few trees in each ensemble. My feature cascading idea was unhelpful, and it greatly hindered runtime performance since the feature space grew rapidly during training.

Takeaways: Decision tree ensembles simply can’t do certain things well. There are two possible reasons for this: 1) there is no way to build hierarchical representations with linearly-combined ensembles of trees; 2) the greedy nature of tree building prevents complex relationships from being modeled properly, and makes it difficult to perform complex computations.

Another more subtle realization was that decision tree training is not very easy to scale. With neural networks, it’s always possible to add more neurons and put more machines in your cluster. With trees, on the other hand, there’s no obvious knob to turn to throw more compute at the problem and consistently get better results. Tree building algorithms themselves are also somewhat harder to parallelize, since they rely on fewer batched operations. I had trouble getting full CPU utilization, even on a single 64-core cloud instance.

Project 3: pca-compress

Source code: https://github.com/unixpickle/pca-compress

Premise: Neural networks contain a lot of redundancy. In many cases, it is possible to match a network’s accuracy with a much smaller, carefully pruned network. However, there are some caveats that make this fact hard to exploit. First of all, it is difficult to train sparse networks from scratch, so sparsity does not help much to accelerate training. Furthermore, the best sparsity results seem to involve unstructured sparsity, i.e. arbitrary sparsity masks that are hard to implement efficiently on modern hardware.

I wanted to find a pruning method that could be applied quickly, ideally before training, that would also be efficient on modern hardware. To do this, I tried a form of rank-reduction where the linear layers (i.e. convolutional and fully-connected layers) were compressed without affecting the final number of activations coming out of each layer. I wanted to do this rank-reduction in a data-aware way, allowing it to exploit redundancy and structure in the data (and in the activations of the network while processing the data). The README of the GitHub repository includes a much more detailed description of the exact algorithms I tried.

Experiments: I tried small-scale MNIST experiments and medium-scale ImageNet experiments. I call the latter “medium-scale” because I reserve “large-scale” for things like GPT-2 and BERT, both of which are out of reach for my compute. I tried pruning before and after training. For ImageNet, I also tried iteratively pruning and re-training. A lot of these experiments were motivated by the lottery ticket hypothesis, which stipulates that it may be possible to train sparse networks from scratch with the right set of initial parameters.

I tried several methods of rank-reduction. The simplest, which was based on PCA, only looked at the statistics of activations and knew nothing about the optimization objective. The more complex approaches, which I call “output-aware”, considered both inputs and outputs, trying to prevent the network’s output from changing too much after pruning.
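As a sketch of the PCA variant (not the repo’s exact algorithm): collect a sample of a linear layer’s output activations on real data, keep their top principal directions, and factor the layer into two smaller ones with the same output size.

import torch

def pca_reduce_linear(layer: torch.nn.Linear, activations: torch.Tensor, rank: int) -> torch.nn.Module:
    """Replace `layer` with two smaller Linear layers whose composition approximates it.

    `activations` is an [N, out_features] sample of the layer's outputs on real data.
    Assumes the layer has a bias.
    """
    mean = activations.mean(dim=0)
    _, _, v = torch.pca_lowrank(activations - mean, q=rank, center=False)
    u = v  # [out_features, rank]: top principal directions of the outputs
    down = torch.nn.Linear(layer.in_features, rank)
    up = torch.nn.Linear(rank, layer.out_features)
    with torch.no_grad():
        down.weight.copy_(u.t() @ layer.weight)
        down.bias.copy_(u.t() @ (layer.bias - mean))
        up.weight.copy_(u)
        up.bias.copy_(mean)
    return torch.nn.Sequential(down, up)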

Results: For my MNIST baseline, I was able to prune networks considerably (upwards of 80%) without any significant loss in accuracy. I also found that an output-aware pruning method was better than the simple PCA baseline. However, results and comparisons on this small baseline did not accurately predict ImageNet results.

On ImageNet, PCA pruning was uniformly the best approach. A pre-trained ResNet-18 pruned with PCA to 50% rank across all convolutional layers experienced a 10% decrease in top-1 performance before any tuning or re-training. With iterative re-training, this gap was reduced to closer to 1.8%, which is still worse than the state-of-the-art. My output-aware pruning methods resulted in a performance gap closer to 30% before re-training (much worse than PCA).

I never got around to experimenting with larger architectures (e.g. ResNet-50) and more severe levels of sparsity (e.g. 90%). Some day I may revisit this project and try such things, but at the moment I simply don’t have the compute available to run these experiments.

At the end of the day, why would I look at these results and consider this a research project that “didn’t pan out”? Mostly because I never managed to get the reduction in compute that I was hoping to achieve. My hope was that I could prune networks without a ton of compute-heavy re-training. My grand ambition was to figure out how to prune networks at or near initialization, but this did not pan out either. My only decent results were with iterative pruning, which is computationally expensive and defeats most of the purpose of the exercise. Perhaps my pruning approach could be used to find good, production ready models, but it cannot be used (as far as I know) to speed up training time.

Takeaways: One takeaway is that results on small-scale experiments don’t always carry over to larger experiments. I developed a bunch of fancy pruning algorithms which worked well on MNIST, but none of them beat the PCA baseline on ImageNet. I never quite figured out why PCA pruning worked the best on ImageNet, but my working theory is that the L2 penalty used during training resulted in activations that had little-to-no variance in discriminatively “unimportant” directions.

Competing in the Obstacle Tower Challenge

I had a lot of fun competing in the Unity Obstacle Tower Challenge. I was at the top of the leaderboard for the majority of the competition, and for the entirety of Round 2. By the end of the competition, my agent was ranked at an average floor of 19.4, greater than the human baseline (15.6) in Juliani et al., and greater than my own personal average performance. This submission outranked all of the other submissions, by a very large margin in most cases.

So how did I do it? The simple answer is that I used human demonstrations in a clever way. There were a handful of other tricks involved as well, and this post will briefly touch on all of them. But first, I want to take a step back and describe how I arrived at my final solution.

Before I looked at the Obstacle Tower environment itself, I assumed that generalization would be the main bottleneck of the competition. This assumption mostly stemmed from my experience creating baselines for the OpenAI Retro Contest, where every model generalized terribly. As such, I started by trying a few primitive solutions that would inject as little information into the model as possible. These solutions included:

  • Evolving a policy with CMA-ES
  • Using PPO to train a policy that looked at tiny observations (e.g. 5×5 images)
  • Using CEM to learn an open-loop action distribution that maximized rewards.

However, none of these solutions reached the fifth floor, and I quickly realized that a PPO baseline did better. Once I started tuning PPO, it quickly became clear that generalization was not the actual bottleneck. It turned out that the 100 training seeds were enough for standard RL algorithms to generalize fairly well. So, instead of focusing on generalization, I simply aimed to make progress on the training set (with a few exceptions, e.g. data augmentation).

My early PPO implementation, which was based on anyrl-py, was hitting a wall at the 10th floor of the environment. It never passed the 10th floor, even by chance, indicating that the environment posed too much of an exploration problem for standard RL. This was when I decided to take a closer look at the environment to see what was going on. It turned out that the 10th floor marked the introduction of the Sokoban puzzle, where the agent must push a block across a room to a square target marked on the floor. This involves taking a consistent set of actions for several seconds (on the order of ~50 timesteps). So much for traditional RL.

At this point, other researchers might have tried something like Curiosity-driven Exploration or Go-Explore. I didn’t even give these methods the time of day. As far as I can tell, these methods all have a strong inductive bias towards visually simple (often 2-dimensional) environments. Exploration is extremely easy with visually simple observations, and even simple image similarity metrics can be used with great success in these environments. On Obstacle Tower, however, observations depend completely on what the camera is pointed at, where the agent is standing in a room, etc. The agent can see two totally different images while standing in the same spot, and it can see two very similar images while standing in two totally different rooms. Moreover, the first instant of pushing a box for the Sokoban puzzle looks very similar to the final moment of pushing the same box. My hypothesis, then, was that traditional exploration algorithms would not be very effective in Obstacle Tower.

If popular exploration algorithms are out, how do we make the agent solve the Sokoban puzzle? There are two approaches I would typically try here: evolutionary algorithms, which explore in parameter space rather than action space, and human demonstrations, which bypass the problem of exploration altogether. With my limited compute capabilities (one machine with one GPU), I decided that human demonstrations would be more practical, since evolution typically burns through a lot of compute to train neural networks on games.

To start myself off with human demonstrations, I created a simple tool to record myself playing Obstacle Tower. After recording a few games, I used behavior cloning (supervised learning) to fit a policy to my demonstrations. Behavior cloning started overfitting very quickly, so I stopped training early and evaluated the resulting policy. It was terrible, but it did perform better than a random agent. I tried fine-tuning this policy with PPO, and was pleased to see that it learned faster than a policy trained from scratch. However, it did not solve the Sokoban puzzle.

Fast-forward a bit, and behavior cloning + fine-tuning still hadn’t broken through the Sokoban puzzle, even with many more demonstrations. At around this time, I rewrote my code in PyTorch so that I could try other imitation learning algorithms more easily. And while algorithms like GAIL did start to push boxes around, I hadn’t seen them reliably solve the 10th floor. I realized that the problem might involve memory, since the agent would often run around in circles doing redundant things, and it had no way of remembering if it had just seen a box or a target.

So, how did I fix the agent’s memory problem? In my experience, recurrent neural networks in RL often don’t remember what you want them to, and they take a huge number of samples to learn to remember anything useful at all. So, instead of using a recurrent neural network to help my agent remember the past, I created a state representation that I could stack up for the past 50 timesteps and then feed to my agent as part of its input. Originally, the state representation was a tuple of (action, reward, has key) values. Even with this simple state representation, behavior cloning worked way better (the test loss reached a much lower point), and the cloned agent had a much better initial score. But I didn’t stop there, because the state representation still said nothing about boxes or targets.
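A sketch of what this looks like in code (the field layout and flattening are assumptions for illustration):

from collections import deque

import numpy as np

HISTORY = 50  # number of past timesteps fed to the agent

class StateHistory:
    """Rolling window of (action, reward, has_key) tuples, flattened into an extra input."""
    def __init__(self, state_size: int = 3):
        self.state_size = state_size
        self.buffer = deque([np.zeros(state_size)] * HISTORY, maxlen=HISTORY)

    def add(self, action: int, reward: float, has_key: bool) -> None:
        self.buffer.append(np.array([action, reward, float(has_key)]))

    def features(self) -> np.ndarray:
        # A [HISTORY * state_size] vector, concatenated with the image features downstream.
        return np.concatenate(list(self.buffer))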

To help the agent remember things that could be useful for solving the Sokoban puzzle, I trained a classifier to identify common objects like boxes, doors, box targets, keys, etc. I then added these classification outputs to the state tuple. This improved behavior cloning even more, and I started to see the behavior cloned agent solve the Sokoban puzzle fairly regularly.

Despite the agent’s improved memory, behavior cloning + fine-tuning was still failing to solve the Sokoban puzzle, and GAIL wasn’t much of an improvement. It seemed that, by the time the agent started reaching the 10th floor, it had totally forgotten how to push boxes! In my experience, this kind of forgetting in RL is often caused by the entropy bonus, which encourages the agent to take random actions as much as possible. This bonus tends to destroy parts of an agent’s pre-trained behavior that do not yield noticeable rewards right away.

This was about the time that prierarchy came in. In addition to my observation about the entropy bonus destroying the agent’s behavior, I also noticed that the behavior cloned agent took reasonable low-level actions, but it did so in ways that were unreasonable in a high-level context. For example, it might push a box all the way to the corner of a room, but it might be the wrong corner. Instead of using an entropy bonus, I wanted a bonus that would keep these low-level actions intact, while allowing the agent to solve the high-level problems that it was struggling with. This is when I implemented the KL term that makes prierarchy what it is, with the behavior cloned policy as the prior.

Once prierarchy was in place, things were pretty much smooth sailing. At this point, it was mostly a matter of recording some more demonstrations and training the agent for longer (on the order of 500M timesteps). However, there were still a few other tricks that I used to improve performance:

  • My actual submissions consisted of two agents. The first one solved floors 1-9, and the second one solved floors 10 and onward. The latter agent was trained with starting floors sampled randomly between 10 and 15. This forced it to learn to solve the Sokoban puzzle immediately, rather than perfecting floors 1-9 first.
  • I used a reduced action space, mostly because I found that it made it easier for me to play the game as a human.
  • My models were based on the CNN architecture from the IMPALA paper. In my experience with RL on video games, this architecture learns and generalizes better than the architecture used in the original Nature article.
  • I used Fixup initialization to help train deeper models.
  • I used MixMatch to train the state classifier with fewer labeled examples than I would have needed otherwise.
  • For behavior cloning, I used traditional types of image data augmentation. However, I also used a mirroring data augmentation where images and actions were mirrored together. This way, I could effectively double the number of training levels, since every level came with its mirror image as well.
  • During prierarchy training, I applied data augmentation to the Obstacle Tower environment to help with overfitting. I never actually verified that this was necessary, and it might not have been, but other contestants definitely struggled with overfitting more than I did.
  • I added a small reward bonus for picking up time orbs. It’s unclear how much of an effect this had, since the agent still missed most of the time orbs. This is one area where improvement would definitely result in a better agent.

I basically checked out for the majority of Round 2: I stopped actively working on the contest, and for a lot of the time I wasn’t even training anything for it. Near the end of the contest, when other contestants started solving the Sokoban puzzle, I trained my agent a little bit more and submitted the new version, but it turned out not to have been necessary.

My code can be found on GitHub. I do not intend to change the repository much at this point, since I want it to remain a reflection of the solution described in this post.

Prierarchy: Implicit Hierarchies

At first blush, hierarchical reinforcement learning (HRL) seems like the holy grail of artificial intelligence. Ideally, it performs long-term exploration, learns over long time-horizons, and has many other benefits. Furthermore, HRL makes it possible to reason about an agent’s behavior in terms of high-level symbols. However, in my opinion, pure HRL is too rigid and doesn’t capture how humans actually operate hierarchically.

Today, I’ll propose an idea that I call “the prierarchy” (a combination of “prior” and “hierarchy”). This idea is an alternative way to look at HRL, without the need to define rigid action boundaries or discrete levels of hierarchy. I am using the prierarchy in the Unity Obstacle Tower Challenge, so I will keep a few details close to my chest for now. However, I intend to update this post once the contest is completed.

Defining the Prierarchy

To understand the prierarchy framework, let’s consider character-level language models. When you run a language model on a sentence, there are some parts of the sentence where the model is very certain what the next token should be (low-entropy), and other parts where it is very uncertain (high-entropy). We can think of this as follows: the high-entropy parts of the sentence mark the starts of high-level actions, while the low-entropy parts represent the execution of those high-level actions. For example, the model likely has more entropy at the starts of words or phrases, while it has less entropy near the ends of words (especially long, unique words). As a consequence of this view, we can say that prompting a language model is the same thing as starting a high-level action and seeing how the model executes this high-level action.

Under the prierarchy framework, we initiate “high-level actions” by taking a sequence of low-probability, low-level actions which kick off a long sequence of high-probability, low-level actions. This assumes that we have some kind of prior distribution over low-level actions, possibly conditioned on observations from the environment. This prior basically encodes the lower-levels of the hierarchy that we want to control with a high-level policy.

The nice thing about the prierarchy is that we never have to make any specific assumptions about action boundaries, since the “high-level actions” aren’t real or concrete; they are just our interpretation of a sequence of continuous entropy values. This means, for one thing, that no low- or high-level policy has to worry about stopping conditions.

Practical Application

In order to implement the prierarchy, you first need a prior. A prior, in this case, is a distribution over low-level actions conditioned on previous observations. This is exactly a “policy” in RL. Thus, you can get a prior in any way you can get a policy: you can train a prior on a different (perhaps denser but misaligned) reward function; you can train a prior via behavior cloning; you can even hand-craft a prior using some kind of approximate solution to your problem. Whichever method you choose, you should make sure that the prior is capable of performing useful behaviors, even if these behaviors occur in the wrong order, for the wrong duration of time, etc.

Once you have a prior, you can train a controller policy fairly easily. First, initialize the policy with the prior, and then perform a policy gradient algorithm with a KL regularizer KL(policy|prior) instead of an entropy bonus. This algorithm is basically saying: “do what the prior says, and inject information into it when you need to in order to achieve rewards”. Note that, if the prior is the uniform distribution over actions, then this is exactly equivalent to traditional policy gradient algorithms.
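As a minimal sketch (discrete actions, vanilla policy gradient; not my exact Obstacle Tower code), the only change from a standard entropy-regularized loss is the KL term against the frozen prior:

import torch
import torch.nn.functional as F

def prierarchy_loss(logits, prior_logits, actions, advantages, kl_coef: float = 0.01):
    """Policy-gradient loss with KL(policy || prior) in place of the usual entropy bonus."""
    logp = F.log_softmax(logits, dim=-1)
    prior_logp = F.log_softmax(prior_logits, dim=-1).detach()  # the prior is frozen
    action_logp = logp.gather(1, actions.unsqueeze(1)).squeeze(1)
    pg_loss = -(advantages * action_logp).mean()
    kl = (logp.exp() * (logp - prior_logp)).sum(dim=-1).mean()
    return pg_loss + kl_coef * kl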

If the prior is low entropy, then you should be able to significantly increase the discount factor. This is because the noise of policy gradient algorithms scales with the amount of information that is injected into the trajectories by sampling from the policy. Also, if you select a reasonable prior, then the prior is likely to explore much better than a random agent. This can be useful in very sparse-reward environments.

Let’s look at a simple (if somewhat contrived) example. In RL, frame-skip is used to help agents take more consistent actions and thus explore better. It also helps with credit assignment, since the policy makes fewer decisions per span of time in the environment. It should be clear that frame-skip defines a prior distribution over action sequences, where every Nth action completely determines the next N-1 actions. My claim, which I won’t demonstrate today, is that you can achieve a similar effect by pre-training a recurrent policy to mimic the frame-skip action distribution, and then applying a policy gradient algorithm to this policy with an appropriately modified discount factor (e.g. 0.99 might become 0.99^0.25) and a KL regularizer against the frame-skip prior. Of course, it would be silly to do this in practice, since the prior is so easy to bake into the environment.

There are tons of other potential applications of this idea. For example, you could learn a better action distribution via an evolutionary algorithm like CMA-ES, and then use this as the prior. This way, the long-term exploration benefits of Evolution Strategies could be combined with the short-term exploitation benefits of policy gradient methods. One could also imagine learning a policy that controls a language model prior to write popular tweets, or one that controls a self-driving car prior for the purposes of navigation.

Conclusion

The main benefit of HRL is really just that it compresses sequences of low-level actions, resulting in better exploration and less noise in the high-level policy gradient. That same compression can be achieved with a good low-level prior, without the need to define explicit action boundaries.

As a final note, the prierarchy training algorithm looks almost identical to pre-training + fine-tuning, indicating that it’s not a very special idea at all. Nonetheless, it seems like a powerful one, and perhaps it will save us from wasting a lot of time on HRL.

Solving murder with Go

I recently saw a blog post entitled Solving murder with Prolog. It was a good read, and I recommend it. One interesting aspect of the solution is that it reads somewhat like the original problem. Essentially, the author translated the clues into Prolog, and got a solution out automatically.

This is actually an essential skill for programmers in general: good code should be readable, almost like plain English—and not just if you’re programming in Prolog. So today, I want to re-solve this problem in Go. As I go, I’ll explain how I optimized for readability. For the eager, here is the final solution.

Constants are very useful for readability. Instead of having a magic number sitting around, try to use an intelligible word or phrase! Let’s create some constants for rooms, people, and weapons:

// Rooms.
const (
	Bathroom = iota
	DiningRoom
	Kitchen
	LivingRoom
	Pantry
	Study
)

// People.
const (
	George = iota
	John
	Robert
	Barbara
	Christine
	Yolanda
)

// Weapons.
const (
	Bag = iota
	Firearm
	Gas
	Knife
	Poison
	Rope
)

The way I am solving the puzzle is as follows: we generate each possible scenario, check if it meets all the clues, and if so, print it out. To do this, we can store the scenario in a structure. Here’s how I chose to represent this structure, which I call Configuration:

type Configuration struct {
	People  []int
	Weapons []int
}

Here, we represent the solution by storing the person and weapon for each room. For example, to see the person in the living room, we check cfg.People[LivingRoom]. To see the weapon in the living room, we do cfg.Weapons[LivingRoom]. This will make it easy to implement clues in a readable fashion. For example, to verify the final clue, we simply check that cfg.Weapons[Pantry] == Gas, which reads just like the original clue.

Next, we need some way to iterate over candidate solutions. Really, we just want to iterate over possible permutations of people and weapons. Luckily, I have already made a library function to generate all permutations of N indices. Using this API, we can enumerate all configurations like so:

for people := range approb.Perms(6) {
	for weapons := range approb.Perms(6) {
		cfg := &Configuration{people, weapons}
		// Check configuration here.
	}
}

Note how this reads exactly like what we really want to do. We are looping through permutations of people and weapons, and we can see that easily. This is because we moved all the complex permutation generation logic into its own function; it’s abstracted away in a well-named container.

Now we just have to filter the configurations using all the available clues. We can make a function that checks if each clue is satisfied, giving us this main loop:

for people := range approb.Perms(6) {
	for weapons := range approb.Perms(6) {
		cfg := &Configuration{people, weapons}
		if !clue1(cfg) || !clue2(cfg) || !clue3(cfg) ||
			!clue4(cfg) || !clue5(cfg) ||
			!clue6(cfg) || !clue7(cfg) ||
			!clue8(cfg) || !clue9(cfg) {
			continue
		}
		fmt.Println("killer is:", cfg.People[Pantry])
	}
}

Next, before we implement the actual clues, let’s implement some helper functions. Well-named helper functions greatly improve readability. Note in particular the isMan helper. It’s very clear what it does, and it abstracts away an arbitrary-seeming 3.

func isMan(person int) bool {
	return person < 3
}

func indexOf(l []int, x int) int {
	for i, y := range l {
		if y == x {
			return i
		}
	}
	panic("unreachable")
}

Finally, we can implement all the clues in a human-readable fashion. Here is what they look like:

func clue1(cfg *Configuration) bool {
	return isMan(cfg.People[Kitchen]) && (cfg.Weapons[Kitchen] == Knife || cfg.Weapons[Kitchen] == Gas)
}

func clue2(cfg *Configuration) bool {
	if cfg.People[Study] == Barbara {
		return cfg.People[Bathroom] == Yolanda
	} else if cfg.People[Study] == Yolanda {
		return cfg.People[Bathroom] == Barbara
	} else {
		return false
	}
}

func clue3(cfg *Configuration) bool {
	bagRoom := indexOf(cfg.Weapons, Bag)
	if cfg.People[bagRoom] == Barbara || cfg.People[bagRoom] == George {
		return false
	}
	return bagRoom != Bathroom && bagRoom != DiningRoom
}

func clue4(cfg *Configuration) bool {
	if cfg.Weapons[Study] != Rope {
		return false
	}
	return !isMan(cfg.People[Study])
}

func clue5(cfg *Configuration) bool {
	return cfg.People[LivingRoom] == John || cfg.People[LivingRoom] == George
}

func clue6(cfg *Configuration) bool {
	return cfg.Weapons[DiningRoom] != Knife
}

func clue7(cfg *Configuration) bool {
	return cfg.People[Study] != Yolanda && cfg.People[Pantry] != Yolanda
}

func clue8(cfg *Configuration) bool {
	firearmRoom := indexOf(cfg.Weapons, Firearm)
	return cfg.People[firearmRoom] == George
}

func clue9(cfg *Configuration) bool {
	return cfg.Weapons[Pantry] == Gas
}

Most of these clues read almost like the original. I have to admit that the indexOf logic isn’t perfect. Let me know how/if you’d do it better!
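
If I had to take a stab at it myself, one small tweak (just a sketch) would be to hide indexOf behind a helper with a more descriptive name, so the clues read in terms of rooms rather than slice searches:

// roomOf returns the room that contains the given weapon.
func roomOf(cfg *Configuration, weapon int) int {
	return indexOf(cfg.Weapons, weapon)
}

With that, clue8 could be written as cfg.People[roomOf(cfg, Firearm)] == George, which reads a little closer to the original clue.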

What I Don’t Know

Starting this holiday season, I want to take some time every day to focus on a broader set of topics than just machine learning. While ML is extremely valuable, working on it day after day has given me tunnel vision. I think it’s important to remind myself that there’s more out there in the world of technology.

This post is going to be an aspirational one. I started by listing a ton of things I’m embarrassed to know nothing about. Then, in the spirit of self-improvement, I came up with a list of project ideas for each topic. I find that I learn best by doing, so I hope these projects will help me master areas where I have little or no prior experience. And who knows, maybe others will benefit from this list as well!

Networking

My knowledge of networking is severely lacking. First of all, I have no idea how the OS network stack works on either Linux or macOS. Also, I’ve never configured a complex network (e.g. for a datacenter), so I have a limited understanding of how routing works.

I am so ignorant when it comes to networking that I often struggle to formulate questions about it. Hopefully, my questions and project ideas actually make sense and turn out to be feasible.

Questions:

  • What APIs does your OS provide to intercept or manipulate network traffic?
  • What actually happens when you connect to a WiFi network?
  • What is a network interface? How is traffic routed through network interfaces?
  • How do VPNs work internally (both on the server and on the client)?
  • How flexible are Linux networking primitives?
  • How does iptables work on Linux? What does it actually do?
  • How do NATs deal with different kinds of traffic (e.g. ICMP)?
  • How does DNS work? How do DNS records propagate? How do custom nameservers (e.g. with NS records) work? How does something like iodine work?

Projects:

  • Re-implement something like Little Snitch.
  • Try using the Berkeley Packet Filter (or some other API) to make a live bandwidth monitor.
  • Implement a packet-level WiFi client that gets all the way up to being able to make DNS queries. I started this with gofi and wifistack, but never finished.
  • Implement a program that exposes a fake LAN with a fake web server on some fake IP address. This will involve writing your own network stack.
  • Implement a user-space NAT that exposes a fake “gateway” through a tunnel interface.
  • Re-implement iodine in a way that parallelizes packet transmission to be faster on satellite internet connections (like on an airplane).
  • Re-implement something like ifconfig using system calls.
  • Connect your computer to both Ethernet and WiFi, and try to write a program that parallelizes an HTTP download over both interfaces simultaneously.
  • Write a script to DoS a VPN server by allocating a ton of IP addresses.
  • Write a simple VPN-like protocol and make a server/client for it.
  • Try to set something up where you can create a new Docker container and assign it its own IP address from a VPN.
  • Try to implement a simple firewall program that hooks into the same level of the network stack as iptables.
  • Implement a fake DNS server and set up your router to use it. The DNS server could forward most requests to a real DNS server, but provide fake addresses for specific domains of your choosing. This would be fun for silly pranks, or for logging domains people visit.
  • Try to bypass WiFi paywalls at hotels, airports, etc.

Cryptocurrency

Cryptocurrencies are extremely popular right now. So, as a tech nerd, I feel kind of lame knowing nothing about them. Maybe I should fix that!

Questions:

  • How do cryptocurrencies actually work?
  • What kinds of network protocols do cryptocurrencies use?
  • What does it mean that Ethereum is a distributed virtual machine?
  • What computations are actually involved in mining cryptocurrencies?

Projects:

  • Implement a toy cryptocurrency.
  • Write a script that, without using high-level APIs, transfers some cryptocurrency (e.g. Bitcoin) from one wallet to another.
  • Write a small program (e.g. “Hello World”) that runs on the Ethereum VM. I honestly don’t even know if this is possible.
  • Try writing a Bitcoin mining program from scratch.

Source Control

Before OpenAI, I worked mostly on solitary projects. As a result, I only used a small subset of the features offered by source control tools like Git. I never had to deal with complex merge conflicts, rebases, etc.

Questions:

  • What are some complicated use-cases for git rebase?
  • How do code reviews typically work on large open source projects?
  • What protocol does git use for remotes? Is a Github repository just a .git directory on a server, or is there more to it than that?
  • What are some common/useful git commands besides git push, git pull, git add, git commit, git merge, git remote, git checkout, git branch, and git rebase? Also, what are some unusual/useful flags for the aforementioned commands?
  • How do you actually set up an editor to work with git add -p?
  • How does git store data internally?

Projects:

  • Try to write a script that uses sockets/HTTPS/SSH to push code to a Github repo. Don’t use any local git commands or APIs.
  • On your own repos, intentionally get yourself into source control messes that you have to figure your way out of.
  • Submit more pull requests to big open source projects.
  • Read a ton of git man pages.
  • Write a program from scratch that converts tarballs to .git directories. The .git directory would represent a repository with a single commit that adds all the files from the tarball.

Machine Learning

Even though I work in ML, I sometimes forget to keep tabs on the field as a whole. Here’s some stuff that I feel I should brush up on:

Questions:

  • How do SOTA object detection systems work?
  • How do OCR systems deal with variable-length strings in images?
  • How do neural style transfer and CycleGAN actually work?

Projects:

  • Make a program that puts boxes around people’s noses in movies.
  • Make a captcha cracker.
  • Make a screenshot-to-text system (easy to generate training data!).
  • Try to make something that takes MNIST digits and makes them look like SVHN digits.

Phones

For something so ubiquitous, phones are still a mystery to me. I’m not sure how easy it is to learn about phones as a user hacking away in his apartment, but I can always try!

Questions:

  • What does the SMS protocol actually look like? How is the data transmitted? Why do different carriers seem to have different character limits?
  • How does modern telephony work? How are calls routed?
  • Why is it so easy to spoof a phone number, and how does one do it?
  • How are Android apps packaged, and how easy is it to reverse engineer them?

Projects:

  • Try reverse engineering SMS apps on your phone. Figure out what APIs deal with SMS messages, and how messages get from the UI all the way to the cellular antenna. Try to encode and send an SMS message from as low a level as possible.
  • Get a job at Verizon/AT&T/Sprint/T-Mobile. This is probably not worth it, but telephony is one of those topics that seem pretty hard to learn about from the outside.

Misc. Tools

I don’t take full advantage of many of the applications I use. I could probably get a decent productivity boost by simply learning more about these tools.

Questions:

  • How do you use ViM macros?
  • What useful keyboard shortcuts does Atom provide?
  • How do custom go get URLs work?
  • How are man pages formatted, and how can I write a good man page?

Projects:

  • Use a ViM macro to un-indent a chunk of lines by one space.
  • For a day, don't let yourself use the mouse at all in your text editor. Look up keyboard shortcuts all you want.
  • For a day, use your editor for all development (including running your code). Feasible with e.g. Hydrogen.
  • Make a go get service that allows for semantic versioning wildcards (e.g. go get myservice.org/github.com/unixpickle/somelib/0.1.x).
  • Write man pages for some existing open source projects. I bet everybody will thank you.

Adversarial Train/Test Splits

The MNIST dataset is split into 60K training images and 10K test images. Everyone pretty much assumes that the training images are enough to learn how to classify the test images, but this is not necessarily the case. So today, in the spirit of being a contrarian, I'll show you how to split the 70K MNIST images in an adversarial manner, such that models trained on the adversarial training set perform poorly on the adversarial test set.

Class imbalance is the most obvious way to make an adversarial split. If you put all the 2’s in the test set, then a model trained on the training set will miss all those 2’s. However, since this answer is fairly boring and doesn’t say very much, let’s assume for the rest of this post that we want class balance. In other words, we want the test set to contain 1K examples of each digit. This also means that, just by looking at the labels, you won’t be able to tell that the split is adversarial.

Here’s another idea: we can train a model on all 70K digits, then get the test set by picking the digits that have the worst loss. What we end up with is a test set that contains the “worst” 1K examples of each class. It turns out that this approach gives good results, even with fairly simple models.

When I tried the above idea with a small, single-layer MLP, the results were surprisingly good. The model, while extremely simple, gets 97% test accuracy on the standard MNIST split. On the adversarial split, however, it only gets 81% test accuracy. Additionally, if you actually look at the samples in the training set versus those in the test set, the difference is quite noticeable. The test set consists of a bunch of distorted, smudgy, or partially-drawn digits.

Example images from the adversarial train/test split.
Training samples (left), testing samples (right).

In a sense, the result I just described was obtained by cheating. I used a certain architecture to generate an adversarial split, and then I used the same architecture to test the effectiveness of that split. Perhaps I just made a split that's hard for an MLP, but not hard in general. To truly demonstrate that the split is adversarial, I needed to test it on a different architecture. To this end, I tried a simple CNN that gets 99% accuracy on the standard test set. On the MLP-generated adversarial split, however, this model drops to 87% test accuracy. So, it's pretty clear that the adversarial split works across architectures.

There’s plenty of directions to take this from here. The following approaches could probably be used to generate even worse splits:

  • Use an ensemble of architectures to select the adversarial test set.
  • Use some kind of iterative process, continually making the split worse and worse.
  • Try some kind of clustering algorithm to make a test set that is a fundamentally different distribution than the training set.

Maybe this result doesn’t mean anything, but I still think it’s cool. It shows that the way you choose a train/test split can matter. It also shows that, perhaps, MNIST is not as uniform as people think.

The code for all of this can be found here.

Decision Trees as RL Policies

In supervised learning, there are very good “shallow” models like XGBoost and SVMs. These models can learn powerful classifiers without an artificial neuron in sight. So why, then, is modern reinforcement learning totally dominated by neural networks? My answer: no good reason. And now I want to show everyone that shallow architectures can do RL too.

Right now, using absolutely no feature engineering, I can train an ensemble of decision trees to play various video games from the raw pixels. The performance isn’t comparable to deep RL algorithms yet, and it may never be for vision-based tasks (for good reason!), but it’s fairly impressive nonetheless.

A tree ensemble playing Atari Pong

So how exactly do I train shallow models on RL tasks? You might have a few ideas, and so did I. Today, I’ll just be telling you about the one that actually worked.

The algorithm itself is so simple that I'm almost embarrassed by my previous (failed) attempts at tree-based RL. Essentially, I use gradient boosting with gradients from a policy gradient estimator. I call the resulting algorithm policy gradient boosting. In practice, I use a slightly more complex algorithm (a tree-based variant of PPO), but there is probably plenty of room for simplification.

With policy gradient boosting, we build up an ensemble of trees in an additive fashion. For every batch of experience, we add a few more trees to our model, making minor adjustments in the direction of the policy gradient. After hundreds or even thousands of trees, we can end up with a pretty good policy.
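
To make this concrete, here's a self-contained toy sketch of the idea (not the actual treeagent code, and plain REINFORCE rather than PPO; every name here is made up for illustration): a one-dimensional contextual bandit where the policy's logits are the scaled sum of depth-1 regression stumps. Each round, we compute the policy gradient with respect to the ensemble's logits, fit a new stump to those per-sample gradients, and append it to the ensemble.

package main

import (
	"fmt"
	"math"
	"math/rand"
)

// stump is a depth-1 regression tree: a per-action logit adjustment on each side of a threshold.
type stump struct {
	threshold   float64
	left, right [2]float64
}

func (s *stump) apply(x float64) [2]float64 {
	if x < s.threshold {
		return s.left
	}
	return s.right
}

// logits is the scaled sum of every stump's output: the ensemble's policy.
func logits(ens []*stump, x, step float64) (l [2]float64) {
	for _, s := range ens {
		out := s.apply(x)
		l[0] += step * out[0]
		l[1] += step * out[1]
	}
	return
}

func probAction1(l [2]float64) float64 {
	return 1 / (1 + math.Exp(l[0]-l[1]))
}

// fitStump fits a stump to per-sample gradient targets: try thresholds on a
// grid, use the mean target on each side as leaf values, keep the best split.
func fitStump(xs []float64, targets [][2]float64) *stump {
	var best *stump
	bestErr := math.Inf(1)
	for t := 0.05; t < 1; t += 0.05 {
		var sums [2][2]float64
		var counts [2]float64
		for i, x := range xs {
			side := 0
			if x >= t {
				side = 1
			}
			sums[side][0] += targets[i][0]
			sums[side][1] += targets[i][1]
			counts[side]++
		}
		s := &stump{threshold: t}
		if counts[0] > 0 {
			s.left = [2]float64{sums[0][0] / counts[0], sums[0][1] / counts[0]}
		}
		if counts[1] > 0 {
			s.right = [2]float64{sums[1][0] / counts[1], sums[1][1] / counts[1]}
		}
		errSum := 0.0
		for i, x := range xs {
			leaf := s.apply(x)
			errSum += math.Pow(targets[i][0]-leaf[0], 2) + math.Pow(targets[i][1]-leaf[1], 2)
		}
		if errSum < bestErr {
			bestErr, best = errSum, s
		}
	}
	return best
}

func main() {
	const step = 0.5
	var ens []*stump
	// Toy contextual bandit: action 1 pays off when the observation is above 0.5.
	for round := 0; round < 200; round++ {
		xs := make([]float64, 64)
		grads := make([][2]float64, 64)
		for i := range xs {
			x := rand.Float64()
			p1 := probAction1(logits(ens, x, step))
			action := 0
			if rand.Float64() < p1 {
				action = 1
			}
			reward := 0.0
			if (x > 0.5) == (action == 1) {
				reward = 1
			}
			// REINFORCE: reward times the gradient of the log-prob w.r.t. the logits.
			grad := [2]float64{-(1 - p1) * reward, -p1 * reward}
			grad[action] += reward
			xs[i], grads[i] = x, grad
		}
		// Boosting step: add one tree that approximates the policy gradient.
		ens = append(ens, fitStump(xs, grads))
	}
	fmt.Println("P(a=1|x=0.9):", probAction1(logits(ens, 0.9, step)), "P(a=1|x=0.1):", probAction1(logits(ens, 0.1, step)))
}

Running this, the two printed probabilities should end up near 1 and near 0, meaning the stump ensemble has learned the toy policy purely from boosted policy gradients.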

Now that I’ve found an algorithm that works pretty well, I want to figure out better hyper-parameters for it. I doubt that tree-based PPO is the best (or even a good) technique, and I doubt that my regularization heuristic is very good either. Yet, even with these somewhat random choices, my models perform well on very difficult tasks! This has convinced me that shallow architectures could really disrupt the modern RL landscape, given the proper care and attention.

All the code for this project is in my treeagent repository, and there are some video demonstrations up on this YouTube playlist. If you’re interested, feel free to contribute to treeagent on Github or send me a PM on Twitter.