Training an Autoencoder on Fashion MNIST

Mix.install([
  {:axon, "~> 0.3.0"},
  {:nx, "~> 0.4.0", override: true},
  {:exla, "~> 0.4.0"},
  {:scidata, "~> 0.1.9"}
])

Nx.Defn.default_options(compiler: EXLA)

Introduction

An autoencoder is a deep learning model that consists of two parts: an encoder and a decoder. The encoder compresses high-dimensional data into a low-dimensional representation and feeds it to the decoder. The decoder tries to recreate the original data from that low-dimensional representation. Autoencoders can be applied to the following problems:

  • Dimensionality reduction
  • Noise reduction
  • Generative models
  • Data augmentation

Let's walk through a basic autoencoder implementation in Axon to get a better understanding of how they work in practice.

Downloading the data

To train and test our model, we use one of the most popular data sets: Fashion MNIST. It consists of small grayscale images of clothing items. Loading this data set is very simple with the help of Scidata.

{image_data, _label_data} = Scidata.FashionMNIST.download()
{bin, type, shape} = image_data

We get the data in a raw format (a binary together with its element type and shape), but this is exactly the information we need to build an Nx tensor.

train_images =
  bin
  |> Nx.from_binary(type)
  |> Nx.reshape(shape)
  |> Nx.divide(255.0)

We also normalize pixel values into the range $[0, 1]$.
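To make the decoding and normalization steps concrete, here is a minimal sketch in plain Elixir that needs no dependencies. The 2x2 "image" and its bytes are made up for illustration, and the sketch assumes one unsigned 8-bit integer per pixel, which is how the pixel values above end up in the range 0 to 255 before the division.

```elixir
# A hypothetical 2x2 "image": one unsigned 8-bit integer per pixel, row by row.
bin = <<0, 51, 102, 255>>
shape = {1, 2, 2}

# Decode each byte and scale it into the range [0, 1],
# mirroring the Nx.divide(255.0) step on the real tensor.
pixels = for <<byte::8 <- bin>>, do: byte / 255.0

# The shape tells us how many values the binary must contain.
expected_count = shape |> Tuple.to_list() |> Enum.product()

IO.inspect(pixels, label: "normalized pixels")
IO.inspect(length(pixels) == expected_count, label: "binary matches shape")
```

The darkest pixel maps to 0.0 and the brightest to 1.0, with every other value falling in between.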

We can visualize one of the images by looking at the tensor heatmap:

Nx.to_heatmap(train_images[1])

Encoder and decoder

First we need to define the encoder and decoder. Both are one-layer neural networks.

In the encoder, we start by flattening the input, going from shape {batch_size, 1, 28, 28} to {batch_size, 784}, and then pass it into a dense layer. Our dense layer has only latent_dim neurons. latent_dim is the dimensionality of the latent space, which holds the compressed representation of the data. Remember, we want our encoder to compress the input data into a lower-dimensional representation, so we choose a latent_dim which is less than the dimensionality of the input.

encoder = fn x, latent_dim ->
  x
  |> Axon.flatten()
  |> Axon.dense(latent_dim, activation: :relu)
end
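To illustrate what a dense layer with a :relu activation computes, here is a rough sketch in plain Elixir lists rather than Nx tensors. The module name DenseSketch, the 4-value input, and the weights are hypothetical and chosen so the result can be checked by hand; Axon learns its own weights during training.

```elixir
defmodule DenseSketch do
  # Weighted sum of the inputs for a single neuron, plus its bias.
  def neuron(input, weights, bias) do
    input
    |> Enum.zip(weights)
    |> Enum.map(fn {x, w} -> x * w end)
    |> Enum.sum()
    |> Kernel.+(bias)
  end

  # ReLU keeps positive values and replaces negative ones with zero.
  def relu(x), do: max(x, 0.0)

  # `weights` holds one weight vector per output neuron.
  def dense(input, weights, biases) do
    weights
    |> Enum.zip(biases)
    |> Enum.map(fn {w, b} -> relu(neuron(input, w, b)) end)
  end
end

# A flattened "image" with 4 pixels, compressed to a latent vector of size 2.
input = [0.0, 0.5, 1.0, 0.25]
weights = [[1.0, 1.0, 1.0, 1.0], [-1.0, -1.0, -1.0, -1.0]]
biases = [0.0, 0.0]

latent = DenseSketch.dense(input, weights, biases)
IO.inspect(latent, label: "latent")
```

The first neuron sums the pixels, while the second produces a negative sum that ReLU clips to zero. The output has fewer values than the input, which is the compression the encoder is after.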

Next, we pass the output of the encoder to the decoder and try to reconstruct the compressed data into its original form. Since our original input had a dimensionality of 784, we use a dense layer with 784 neurons. Because our original data was normalized to have pixel values between 0 and 1, we use a :sigmoid activation in our dense layer to squeeze output values between 0 and 1. Our original input shape was 28x28, so we use Axon.reshape to convert the flattened representation of the outputs into an image with the correct width and height.

decoder = fn x ->
  x
  |> Axon.dense(784, activation: :sigmoid)
  |> Axon.reshape({:batch, 1, 28, 28})
end
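As a quick check of why :sigmoid suits pixel values normalized to the range 0 to 1, this sketch evaluates the standard logistic function on a few arbitrary inputs in plain Elixir. The sample inputs are illustrative only.

```elixir
# The logistic (sigmoid) function maps any real number into the interval (0, 1).
sigmoid = fn x -> 1.0 / (1.0 + :math.exp(-x)) end

inputs = [-10.0, -1.0, 0.0, 1.0, 10.0]
outputs = Enum.map(inputs, sigmoid)

Enum.zip(inputs, outputs)
|> Enum.each(fn {x, y} -> IO.puts("sigmoid(#{x}) = #{y}") end)
```

Large negative inputs land close to 0, large positive inputs land close to 1, and an input of 0 maps to the midpoint 0.5.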

If we just bind the encoder and decoder sequentially, we'll get the desired model. This was pretty smooth, wasn't it?

model =
  Axon.input("input", shape: {nil, 1, 28, 28})
  |> encoder.(64)
  |> decoder.()
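It can help to work out how large this model is. The arithmetic below is a back-of-the-envelope sketch, assuming each dense layer has one weight per input-output pair plus one bias per output neuron; the helper dense_params is introduced here for illustration and is not part of Axon.

```elixir
input_dim = 1 * 28 * 28
latent_dim = 64

# A dense layer has inputs * units weights and one bias per unit.
dense_params = fn inputs, units -> inputs * units + units end

encoder_params = dense_params.(input_dim, latent_dim)
decoder_params = dense_params.(latent_dim, input_dim)
total_params = encoder_params + decoder_params

# How many input values are squeezed into each latent value.
compression = input_dim / latent_dim

IO.inspect(encoder_params, label: "encoder parameters")
IO.inspect(decoder_params, label: "decoder parameters")
IO.inspect(total_params, label: "total parameters")
IO.inspect(compression, label: "compression factor")
```

With a latent_dim of 64, each image is represented by roughly twelve times fewer values than its 784 pixels.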

Training the model

Finally, we can train the model. We'll use the :adam optimizer and the :mean_squared_error loss with Axon.Loop.trainer. Our loss function will measure the aggregate error between the pixels of the original images and the model's reconstructed images. We'll also track :mean_absolute_error using Axon.Loop.metric. Axon.Loop.run trains the model with the given training data.

batch_size = 32
epochs = 5

batched_images = Nx.to_batched(train_images, batch_size)
train_batches = Stream.zip(batched_images, batched_images)

params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :adam)
  |> Axon.Loop.metric(:mean_absolute_error, "Error")
  |> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)
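Zipping the batched images with themselves is what makes this an autoencoder: each training step receives an {input, target} pair in which the target is the input itself. The sketch below shows the same pattern with plain integers standing in for images (the six values and the batch size of 2 are made up), and also works out the number of steps per epoch, assuming the standard Fashion MNIST training split of 60,000 images.

```elixir
# Stand-ins for images: six "images" numbered 1 to 6.
images = Enum.to_list(1..6)
toy_batch_size = 2

# Split into batches, then pair every batch with itself as {input, target}.
batched = Enum.chunk_every(images, toy_batch_size)
toy_train_batches = Stream.zip(batched, batched)

pairs = Enum.to_list(toy_train_batches)
IO.inspect(pairs, label: "input/target pairs")

# With 60,000 training images and a batch size of 32:
steps_per_epoch = div(60_000, 32)
IO.inspect(steps_per_epoch, label: "steps per epoch")
```

Every tuple contains the same batch twice, so the model is scored on how well it reproduces what it was given.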

Extra: losses

To better understand what mean absolute error (MAE) and mean squared error (MSE) are, let's go through an example.
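For reference, the two errors can be written as follows for a single sample with $n$ pixels, where $\hat{y}$ is the prediction and $y$ is the target. The symbols $n$, $\hat{y}$, and $y$ are notation introduced here for illustration; they correspond to y_pred and y in the code that follows.

$$
\mathrm{MSE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2
\qquad
\mathrm{MAE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \lvert \hat{y}_i - y_i \rvert
$$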

# Error definitions for a single sample

mean_square_error = fn y_pred, y ->
  y_pred
  |> Nx.subtract(y)
  |> Nx.power(2)
  |> Nx.mean()
end

mean_absolute_error = fn y_pred, y ->
  y_pred
  |> Nx.subtract(y)
  |> Nx.abs()
  |> Nx.mean()
end

We will work with a sample image of a shoe, a slightly noised version of that image, and also an entirely different image from the dataset.

shoe_image = train_images[0]
noised_shoe_image = Nx.add(shoe_image, Nx.random_normal(shoe_image, 0.0, 0.05))
other_image = train_images[1]
:ok

For the same image both errors should be 0, because when we have two exact copies, there is no pixel difference.

{
  mean_square_error.(shoe_image, shoe_image),
  mean_absolute_error.(shoe_image, shoe_image)
}

Now the noised image:

{
  mean_square_error.(shoe_image, noised_shoe_image),
  mean_absolute_error.(shoe_image, noised_shoe_image)
}

And a different image:

{
  mean_square_error.(shoe_image, other_image),
  mean_absolute_error.(shoe_image, other_image)
}

As we can see, the noised image has a non-zero MSE and MAE, but both are much smaller than the errors between two completely different pictures. In other words, both of these error types measure the level of similarity between images. A small error implies decent prediction values. On the other hand, a large error value suggests poor quality of predictions.
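The same comparison can be checked by hand on a tiny example. This sketch reimplements both errors over plain Elixir lists; the three-pixel "images" are made up for illustration and are not taken from the dataset.

```elixir
mean = fn values -> Enum.sum(values) / length(values) end

# Mean squared error over two equally sized lists.
mse = fn y_pred, y ->
  Enum.zip(y_pred, y)
  |> Enum.map(fn {p, t} -> (p - t) * (p - t) end)
  |> mean.()
end

# Mean absolute error over two equally sized lists.
mae = fn y_pred, y ->
  Enum.zip(y_pred, y)
  |> Enum.map(fn {p, t} -> abs(p - t) end)
  |> mean.()
end

original = [0.0, 0.5, 1.0]
noised = [0.1, 0.4, 1.0]
other = [1.0, 0.0, 0.0]

IO.inspect({mse.(original, original), mae.(original, original)}, label: "same")
IO.inspect({mse.(original, noised), mae.(original, noised)}, label: "noised")
IO.inspect({mse.(original, other), mae.(original, other)}, label: "other")
```

The ordering matches what we saw with the real images: zero error for an exact copy, a small error for the noised version, and a much larger error for the unrelated one.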

If you look at our implementation of MAE and MSE, you will notice that they are very similar. MAE and MSE are also called the $L_1$ and $L_2$ loss respectively, after the $L_1$ and $L_2$ norms. The $L_2$ loss (MSE) is typically preferred because it is a smoother function, whereas the $L_1$ loss is not differentiable at zero and is often more difficult to optimize with stochastic gradient descent (SGD).
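One way to see the difference in smoothness is to compare the derivatives of the two losses with respect to a single residual $e$ (prediction minus target, a symbol introduced here for illustration): the squared error $e^2$ has derivative $2e$, while the absolute error $\lvert e \rvert$ has derivative $\pm 1$ and no derivative at $e = 0$. The sketch below evaluates both on a few arbitrary residuals; returning 0.0 at exactly zero is a common convention assumed here, not something dictated by the loss.

```elixir
# Derivative of e^2 with respect to e.
squared_grad = fn e -> 2.0 * e end

# Derivative of |e| with respect to e; 0.0 at e == 0 is a chosen convention.
absolute_grad = fn e ->
  cond do
    e > 0 -> 1.0
    e < 0 -> -1.0
    true -> 0.0
  end
end

residuals = [1.0, 0.1, 0.001, -0.001]

squared_grads = Enum.map(residuals, squared_grad)
absolute_grads = Enum.map(residuals, absolute_grad)

IO.inspect(squared_grads, label: "squared error gradients")
IO.inspect(absolute_grads, label: "absolute error gradients")
```

The squared-error gradient shrinks as the prediction approaches the target, so updates naturally get smaller near the optimum. The absolute-error gradient keeps the same magnitude and flips sign abruptly when the residual crosses zero.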

Inference

Now, let's see how our model is doing! We will compare a sample image before and after compression.

sample_image = train_images[0..0//1]
compressed_image = Axon.predict(model, params, sample_image, compiler: EXLA)

sample_image
|> Nx.to_heatmap()
|> IO.inspect(label: "Original")

compressed_image
|> Nx.to_heatmap()
|> IO.inspect(label: "Compressed")

:ok

As we can see, the generated image is similar to the input image. The only difference between them is the absence of a sign in the middle of the second shoe. The model treated the sign as noise and blended it into the plain shoe.