Classifying handwritten digits

Mix.install([
  {:axon, "~> 0.3.0"},
  {:nx, "~> 0.4.0", override: true},
  {:exla, "~> 0.4.0"},
  {:req, "~> 0.3.1"}
])


This livebook walks you through training a basic neural network using Axon, accelerated by the EXLA compiler. We'll be working with the MNIST dataset, a collection of handwritten digits with corresponding labels. The goal is to train a model that correctly classifies each handwritten digit with a single label from 0 to 9.

Retrieving and exploring the dataset

The MNIST dataset is freely available online. Using Req, we'll download both the training images and the training labels. Both train_images and train_labels are gzip-compressed binary data; fortunately, Req takes care of the decompression for us.

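The ubyte files use the IDX format, which is described on the original MNIST dataset page. Each file starts with a magic number and some metadata (the number of items and, for images, the number of rows and columns), all stored as big-endian 32-bit integers. We can use binary pattern matching to extract the information we want; in this case, the raw image and label bytes.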

base_url = ""
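# Download the gzipped training images and labels; Req decompresses the bodies for us
%{body: train_images} = Req.get!(base_url <> "train-images-idx3-ubyte.gz")
%{body: train_labels} = Req.get!(base_url <> "train-labels-idx1-ubyte.gz")

# Skip the 32-bit magic number, read the big-endian header fields, keep the rest as raw bytes
<<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = train_images
<<_::32, n_labels::32, labels::binary>> = train_labels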
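To see what the pattern match above does, here is a small self-contained sketch in plain Elixir. It builds a tiny synthetic IDX-style file by hand (two 2x2 "images", not real MNIST data), compresses and decompresses it with Erlang's `:zlib` module as a stand-in for the gzip decompression Req performs, and then parses the header. The `IdxHeader` module is hypothetical, and the magic-number clauses assume the IDX convention of two zero bytes, a type byte (0x08 for unsigned bytes), and a dimension count.

```elixir
defmodule IdxHeader do
  # Hypothetical helper. Every header field is a big-endian 32-bit unsigned
  # integer, which is the default for `::32` in a binary pattern.

  # Magic number 0x00000803: unsigned bytes (0x08), 3 dimensions (count, rows, cols)
  def parse(<<0, 0, 0x08, 3, n::32, rows::32, cols::32, data::binary>>),
    do: {:images, n, rows, cols, data}

  # Magic number 0x00000801: unsigned bytes (0x08), 1 dimension (count)
  def parse(<<0, 0, 0x08, 1, n::32, data::binary>>),
    do: {:labels, n, data}
end

# Synthetic file: header for 2 images of 2x2 pixels, followed by 8 pixel bytes
fake_file = <<0x00000803::32, 2::32, 2::32, 2::32, 0, 64, 128, 255, 255, 128, 64, 0>>

# Round-trip through gzip, as the real .gz downloads would be
compressed = :zlib.gzip(fake_file)

{:images, n_images, n_rows, n_cols, pixels} =
  compressed |> :zlib.gunzip() |> IdxHeader.parse()

IO.inspect({n_images, n_rows, n_cols, byte_size(pixels)})
# prints {2, 2, 2, 8}
```

The livebook code ignores the magic number with `_::32`; matching its bytes explicitly, as this sketch does, is an optional sanity check that fails loudly if the wrong file is passed in.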

We can easily read that binary data into a tensor using Nx.from_binary/2, which expects a raw binary and a data type. In this case, both images and labels are stored as unsigned 8-bit integers, so the type is {:u, 8}. We can start by parsing our images:

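images =
  images
  # Interpret each byte as an unsigned 8-bit pixel value (0..255)
  |> Nx.from_binary({:u, 8})
  # Turn the flat tensor into named {images, channels, height, width} dimensions
  |> Nx.reshape({n_images, 1, n_rows, n_cols}, names: [:images, :channels, :height, :width])
  # Scale pixel values into [0, 1]; the result is a float tensor
  |> Nx.divide(255)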

Nx.from_binary/2 returns a flat tensor. Using Nx.reshape/3, we can reshape this flat tensor into meaningful dimensions. Notice that we also normalized the tensor by dividing the input data by 255. This scales every pixel value into the range 0 to 1, which often leads to better behavior when training models. Now, let's see what these images look like:

images[[images: 0..4]] |> Nx.to_heatmap()

In the reshape operation above, we gave each dimension of the tensor a name. Named dimensions make operations like slicing easier and make your code easier to understand. Here, we slice the :images dimension of the images tensor to obtain the first 5 training images, then convert them to a heatmap for easy visualization.
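The reshape and normalization steps can be mimicked on plain Elixir lists, which makes the arithmetic easy to check. This sketch uses a hypothetical 2x3 "image" stored as six raw bytes: `:binary.bin_to_list/1` plays the role of Nx.from_binary/2, `Enum.chunk_every/2` plays the role of Nx.reshape/3, and dividing by 255 scales every pixel into [0, 1].

```elixir
# Hypothetical 2x3 grayscale image stored row by row as unsigned bytes
flat = <<0, 51, 102, 153, 204, 255>>
cols = 3

image =
  flat
  # One integer per pixel, each in 0..255
  |> :binary.bin_to_list()
  # `/` always returns a float, so this scales every pixel into [0.0, 1.0]
  |> Enum.map(&(&1 / 255))
  # Group the flat list into rows of `cols` pixels each
  |> Enum.chunk_every(cols)

IO.inspect(image)
# prints [[0.0, 0.2, 0.4], [0.6, 0.8, 1.0]]
```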

It's common to train neural networks in batches (more precisely called minibatches, although you'll see the two terms used interchangeably). We can split our images into batches of 32 like this:

images = Nx.to_batched(images, 32)
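Assuming the standard MNIST training set of 60,000 images, batches of 32 divide it evenly into 60,000 / 32 = 1,875 batches. The sketch below checks that arithmetic with `Stream.chunk_every/2` over a range of stand-in indices rather than real images; the stream is lazy, so a batch is only built when it is consumed.

```elixir
# Stand-in for 60,000 training examples; only their count matters here
examples = 1..60_000
batch_size = 32

# Lazily group consecutive examples into batches of `batch_size`
batches = Stream.chunk_every(examples, batch_size)

IO.inspect(Enum.count(batches))
# prints 1875
IO.inspect(batches |> Enum.at(0) |> length())
# prints 32
```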

Next, we need to get our labels into batches as well, but first we need to one-hot encode them. One-hot encoding converts a label such as 3, 5, or 7 into a vector of 0s with a single 1 at the label's index. For example, the label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].

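targets =
  labels
  |> Nx.from_binary({:u, 8})
  # Shape {n_labels} -> {n_labels, 1} so it broadcasts against the 10 classes
  |> Nx.new_axis(-1)
  # Compare every label with 0..9, giving 1 where they match: shape {n_labels, 10}
  |> Nx.equal(Nx.tensor(Enum.to_list(0..9)))
  |> Nx.to_batched(32)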
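The Nx pipeline above relies on broadcasting: comparing a {n_labels, 1} tensor with a {10} tensor yields a {n_labels, 10} tensor in which each row is one label compared against every class. The same idea written with plain Elixir lists, as a sketch for a single label and a small batch (the `one_hot` function is hypothetical, not part of Nx or Axon):

```elixir
# Hypothetical helper: compare `label` with each class 0..9, 1 where they match
one_hot = fn label ->
  for class <- 0..9, do: if(class == label, do: 1, else: 0)
end

IO.inspect(one_hot.(3))
# prints [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]

# A small batch of labels becomes a list of one-hot rows, one row per label
batch = Enum.map([5, 0, 4], one_hot)
```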

Defining the model

Let's start by defining a simple model:

model =
  Axon.input("input", shape: {nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(10, activation: :softmax)

All Axon models start with an input layer that tells subsequent layers what shapes to expect. We then use Axon.flatten/2, which flattens the previous layer by collapsing every dimension except the first (batch) dimension into a single dimension, so {nil, 1, 28, 28} becomes {nil, 784}. Our model consists of 2 fully connected layers with 128 and 10 units respectively. The first layer uses the :relu activation, which returns max(0, input) element-wise. The final layer uses the :softmax activation to return a probability distribution over the 10 labels [0-9].
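To make the layer descriptions concrete, here is a plain-Elixir sketch of the two activations and of the model's parameter count. The `Activations` module is hypothetical; Axon's real implementations operate on Nx tensors. The softmax subtracts the maximum input before exponentiating, a common trick that avoids overflow without changing the result. The parameter count assumes each dense layer has a weight matrix plus a bias vector, which is Axon's default.

```elixir
defmodule Activations do
  # ReLU: max(0, x), applied element-wise
  def relu(xs), do: Enum.map(xs, &max(0.0, &1))

  # Softmax: exp(x_i) / sum_j exp(x_j). Subtracting the maximum first keeps
  # :math.exp/1 from overflowing and leaves the result unchanged.
  def softmax(xs) do
    m = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end

IO.inspect(Activations.relu([-2.0, 0.5, 3.0]))
# prints [0.0, 0.5, 3.0]

# Entries are positive, sum to 1 (up to rounding), and preserve the input ordering
probs = Activations.softmax([1.0, 2.0, 3.0])

# A dense layer with a bias has inputs * units weights plus units biases
dense_params = fn inputs, units -> inputs * units + units end
# {1, 28, 28} flattens to 784 features
flattened = 1 * 28 * 28
total = dense_params.(flattened, 128) + dense_params.(128, 10)

IO.inspect(total)
# prints 101770
```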


Training

In Axon, we express the task of training using a declarative loop API. First, we need to specify a loss function and an optimizer; there are many built-in variants to choose from. In this example, we'll use categorical cross-entropy and the Adam optimizer. We will also keep track of the accuracy metric. Finally, we run the training loop, passing our batched images and labels. We'll train for 10 epochs using the EXLA compiler.

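params =
  model
  |> Axon.Loop.trainer(:categorical_cross_entropy, :adam)
  |> Axon.Loop.metric(:accuracy, "Accuracy")
  # Pair each image batch with its target batch and train for 10 epochs with EXLA
  |> Axon.Loop.run(Stream.zip(images, targets), %{}, epochs: 10, compiler: EXLA)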
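As a rough illustration of what the loss and the accuracy metric measure, here is a plain-Elixir sketch for a tiny hypothetical batch of two examples with three classes. It averages the per-example cross-entropy over the batch; it illustrates the formulas only and is not Axon's implementation, which works on Nx tensors. The `Metrics` module is hypothetical.

```elixir
defmodule Metrics do
  # Index of the largest value; on ties, the first such index wins
  def argmax(xs) do
    xs |> Enum.with_index() |> Enum.max_by(&elem(&1, 0)) |> elem(1)
  end

  # Categorical cross-entropy for one example: -sum_i y_i * log(p_i).
  # Classes with y_i = 0 contribute nothing, so they are skipped
  # (which also avoids evaluating log(0)).
  def cross_entropy(target, probs) do
    target
    |> Enum.zip(probs)
    |> Enum.reject(fn {y, _p} -> y == 0 end)
    |> Enum.map(fn {y, p} -> -y * :math.log(p) end)
    |> Enum.sum()
  end

  # Mean cross-entropy over a batch of one-hot targets and predicted distributions
  def mean_cross_entropy(targets, preds) do
    losses = Enum.zip_with(targets, preds, &cross_entropy/2)
    Enum.sum(losses) / length(losses)
  end

  # Fraction of examples whose most probable class equals the true class
  def accuracy(targets, preds) do
    correct =
      targets
      |> Enum.zip(preds)
      |> Enum.count(fn {t, p} -> argmax(t) == argmax(p) end)

    correct / length(targets)
  end
end

targets = [[0, 1, 0], [1, 0, 0]]
preds = [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1]]

IO.inspect(Metrics.accuracy(targets, preds))
# prints 0.5 (the second example is misclassified)

# (-log(0.8) - log(0.3)) / 2, roughly 0.7136
IO.inspect(Metrics.mean_cross_entropy(targets, preds))
```

Note how the misclassified second example dominates the loss: it assigned only 0.3 to the true class, so its term is much larger than the first example's.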


Prediction

Now that we have the trained parameters, we can use them to make predictions with Axon.predict.

first_batch = Enum.at(images, 0)

output = Axon.predict(model, params, first_batch)

For each image, the model outputs a probability distribution, which tells us how confident the model is in each possible digit. Let's look at the most probable digit for each image:

Nx.argmax(output, axis: 1)
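Nx.argmax/2 with axis: 1 picks, for each row of the {batch, 10} output, the index of the largest probability, and that index is the predicted digit. A plain-Elixir sketch of the same row-wise operation, using made-up probability rows rather than real model output, which also keeps the winning probability as a confidence score:

```elixir
# Made-up model outputs for two images (each row sums to 1); not real predictions
output = [
  [0.01, 0.02, 0.01, 0.01, 0.02, 0.05, 0.01, 0.80, 0.04, 0.03],
  [0.05, 0.02, 0.70, 0.05, 0.03, 0.03, 0.04, 0.03, 0.03, 0.02]
]

predictions =
  Enum.map(output, fn row ->
    # Pair each probability with its digit, then keep the most probable pair
    {prob, digit} = row |> Enum.with_index() |> Enum.max_by(&elem(&1, 0))
    %{digit: digit, confidence: prob}
  end)

IO.inspect(Enum.map(predictions, & &1.digit))
# prints [7, 2]
```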

If you compare the first five predictions with the heatmap of the first five training images above, you'll see that the predictions match the data!