Training an Autoencoder on Fashion MNIST
Mix.install([
{:axon, "~> 0.3.0"},
{:nx, "~> 0.4.0", override: true},
{:exla, "~> 0.4.0"},
{:scidata, "~> 0.1.9"}
])
Nx.Defn.default_options(compiler: EXLA)
Introduction
An autoencoder is a deep learning model that consists of two parts: an encoder and a decoder. The encoder compresses high-dimensional data into a low-dimensional representation and feeds it to the decoder. The decoder tries to reconstruct the original data from that low-dimensional representation. Autoencoders can be used for the following problems:
- Dimensionality reduction
- Noise reduction
- Generative models
- Data augmentation
Let's walk through a basic autoencoder implementation in Axon to get a better understanding of how they work in practice.
Downloading the data
To train and test our model, we use one of the most popular data sets: Fashion MNIST. It consists of small 28x28 grayscale images of clothing. Loading this data set is very simple with the help of Scidata.
{image_data, _label_data} = Scidata.FashionMNIST.download()
{bin, type, shape} = image_data
We get the data in a raw format, but this is exactly the information we need to build an Nx tensor.
train_images =
bin
|> Nx.from_binary(type)
|> Nx.reshape(shape)
|> Nx.divide(255.0)
We also normalize pixel values into the range $[0, 1]$.
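To make the decode, reshape and normalize steps concrete, here is a plain-Elixir sketch (no Nx) of the same idea on a hypothetical 2x2 "image" stored as four unsigned 8-bit bytes. It only mirrors what `Nx.from_binary/2`, `Nx.reshape/2` and `Nx.divide/2` do; it is not how Nx implements them.

```elixir
# A hypothetical 2x2 image: four unsigned 8-bit pixels, like the {:u, 8} binary above.
raw = <<0, 51, 204, 255>>

pixels =
  raw
  # Decode each byte as an integer in 0..255 (similar to Nx.from_binary/2 with {:u, 8}).
  |> :binary.bin_to_list()
  # Scale into [0, 1] (similar to Nx.divide(255.0)).
  |> Enum.map(&(&1 / 255))
  # Group into rows of width 2 (similar to Nx.reshape/2 to {2, 2}).
  |> Enum.chunk_every(2)

IO.inspect(pixels, label: "normalized 2x2 image")
```

Scaling and reshaping are independent here, so doing them in either order gives the same result.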
We can visualize one of the images by looking at the tensor heatmap:
Nx.to_heatmap(train_images[1])
Encoder and decoder
First we need to define the encoder and decoder. Both are one-layer neural networks.
In the encoder, we start by flattening the input, which takes us from shape {batch_size, 1, 28, 28} to {batch_size, 784}, and then we pass the result into a dense layer. Our dense layer has only latent_dim neurons. latent_dim is the size of the latent space, the compressed representation of the data. Remember, we want our encoder to compress the input data into a lower-dimensional representation, so we choose a latent_dim that is less than the dimensionality of the input.
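The shape bookkeeping behind the encoder can be checked with a few lines of plain Elixir. The batch size of 32 is only an example, and the latent size of 64 matches the value used when the model is built below; the "compression ratio" here is simply input dimensions divided by latent dimensions.

```elixir
# Example input shape {batch_size, channels, height, width}; 32 is an arbitrary batch size.
input_shape = {32, 1, 28, 28}
latent_dim = 64

# Flattening keeps the batch axis and multiplies the remaining dimensions together.
{batch, channels, height, width} = input_shape
flat_shape = {batch, channels * height * width}

# Ratio of input dimensions to latent dimensions.
compression_ratio = elem(flat_shape, 1) / latent_dim

IO.inspect(flat_shape, label: "flattened shape")
IO.inspect(compression_ratio, label: "compression ratio")
```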
encoder = fn x, latent_dim ->
x
|> Axon.flatten()
|> Axon.dense(latent_dim, activation: :relu)
end
Next, we pass the output of the encoder to the decoder and try to reconstruct the compressed data into its original form. Since our original input had a dimensionality of 784, we use a dense layer with 784 neurons. Because our original data was normalized to have pixel values between 0 and 1, we use a :sigmoid activation in our dense layer to squeeze the output values into that range. Our original input shape was 28x28, so we use Axon.reshape to convert the flattened outputs back into an image with the correct width and height.
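As a rough plain-Elixir illustration (not Axon's implementation), the sketch below applies the logistic sigmoid to a few arbitrary pre-activation values and then regroups a flat 784-element list into 28 rows of 28, which is the per-example effect of the reshape (Axon's reshape additionally adds the channel axis).

```elixir
# Logistic sigmoid: maps any real number into the open interval (0, 1).
sigmoid = fn x -> 1 / (1 + :math.exp(-x)) end

# Arbitrary example pre-activations, from very negative to very positive.
outputs = Enum.map([-10.0, -1.0, 0.0, 1.0, 10.0], sigmoid)
IO.inspect(outputs, label: "sigmoid outputs")

# A flat vector of 784 values regrouped into 28 rows of 28 "pixels".
flat = Enum.to_list(1..784)
rows = Enum.chunk_every(flat, 28)
IO.inspect({length(rows), length(hd(rows))}, label: "rows x columns")
```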
decoder = fn x ->
x
|> Axon.dense(784, activation: :sigmoid)
|> Axon.reshape({:batch, 1, 28, 28})
end
If we just chain the encoder and decoder sequentially, we get the desired model. This was pretty smooth, wasn't it?
model =
Axon.input("input", shape: {nil, 1, 28, 28})
|> encoder.(64)
|> decoder.()
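To get a feel for the model's size, we can count its trainable parameters by hand. This sketch assumes each dense layer has a weight matrix of size inputs x units plus a bias vector of size units (Axon's dense layers include a bias by default); it is arithmetic only and does not inspect the Axon model itself.

```elixir
# Parameters of one dense layer: weights (inputs x units) plus one bias per unit.
dense_params = fn inputs, units -> inputs * units + units end

# Encoder: 784 flattened pixels -> 64 latent units.
encoder_params = dense_params.(784, 64)
# Decoder: 64 latent units -> 784 output pixels.
decoder_params = dense_params.(64, 784)

IO.inspect(encoder_params, label: "encoder parameters")
IO.inspect(decoder_params, label: "decoder parameters")
IO.inspect(encoder_params + decoder_params, label: "total parameters")
```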
Training the model
Finally, we can train the model. We'll use the :adam optimizer and the :mean_squared_error loss with Axon.Loop.trainer. Our loss function will measure the average squared difference between the pixels of the original images and the model's reconstructed images. We'll also track :mean_absolute_error as a metric using Axon.Loop.metric. Axon.Loop.run trains the model with the given training data.
batch_size = 32
epochs = 5
batched_images = Nx.to_batched(train_images, batch_size)
train_batches = Stream.zip(batched_images, batched_images)
params =
model
|> Axon.Loop.trainer(:mean_squared_error, :adam)
|> Axon.Loop.metric(:mean_absolute_error, "Error")
|> Axon.Loop.run(train_batches, %{}, epochs: epochs, compiler: EXLA)
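Each batch is zipped with itself because an autoencoder's target is its own input. The arithmetic below estimates how many optimizer steps the loop performs, assuming the 60,000-image training split that `Scidata.FashionMNIST.download/0` returns; it is a back-of-the-envelope sketch, not output from Axon.

```elixir
# Assumed size of the Fashion MNIST training split.
num_images = 60_000
batch_size = 32
epochs = 5

# Full batches per epoch and any images left over.
steps_per_epoch = div(num_images, batch_size)
leftover = rem(num_images, batch_size)

IO.inspect({steps_per_epoch, leftover}, label: "steps per epoch, leftover images")
IO.inspect(steps_per_epoch * epochs, label: "total optimizer steps")
```

Because 60,000 divides evenly by 32, no partial batch is produced here.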
Extra: losses
To better understand what mean absolute error (MAE) and mean squared error (MSE) are, let's go through an example.
# Error definitions for a single sample
mean_square_error = fn y_pred, y ->
y_pred
|> Nx.subtract(y)
|> Nx.power(2)
|> Nx.mean()
end
mean_absolute_error = fn y_pred, y ->
y_pred
|> Nx.subtract(y)
|> Nx.abs()
|> Nx.mean()
end
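Written as formulas, the two functions above compute the following over the $n$ entries of a tensor, where $\hat{y}$ is the prediction and $y$ is the target (for a single 28x28 image, $n = 784$):

$$
\mathrm{MSE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} (\hat{y}_i - y_i)^2, \qquad
\mathrm{MAE}(\hat{y}, y) = \frac{1}{n} \sum_{i=1}^{n} \lvert \hat{y}_i - y_i \rvert
$$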
We will work with a sample image of a shoe, a slightly noisy version of that image, and an entirely different image from the dataset.
shoe_image = train_images[0]
noised_shoe_image = Nx.add(shoe_image, Nx.random_normal(shoe_image, 0.0, 0.05))
other_image = train_images[1]
:ok
For the same image, both errors should be 0, because two exact copies have no pixel difference.
{
mean_square_error.(shoe_image, shoe_image),
mean_absolute_error.(shoe_image, shoe_image)
}
Now the noised image:
{
mean_square_error.(shoe_image, noised_shoe_image),
mean_absolute_error.(shoe_image, noised_shoe_image)
}
And a different image:
{
mean_square_error.(shoe_image, other_image),
mean_absolute_error.(shoe_image, other_image)
}
As we can see, the noisy image has a non-zero MSE and MAE, but both are much smaller than the errors between two completely different pictures. In other words, both of these error types measure how similar two images are. A small error means the prediction is close to the target, while a large error suggests poor-quality predictions.
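Although both errors track similarity, they weight mistakes differently. In this plain-Elixir sketch with hypothetical per-pixel differences, two cases have the same total absolute error, but MSE grows when that error is concentrated in a single pixel, because squaring amplifies large differences.

```elixir
# Plain-Elixir versions of the two errors, taking a list of per-pixel differences.
mse = fn diffs -> Enum.sum(Enum.map(diffs, &(&1 * &1))) / length(diffs) end
mae = fn diffs -> Enum.sum(Enum.map(diffs, &abs/1)) / length(diffs) end

# Hypothetical residuals with the same total absolute error: spread out vs. concentrated.
spread = [0.1, 0.1, 0.1, 0.1]
concentrated = [0.0, 0.0, 0.0, 0.4]

IO.inspect({mae.(spread), mae.(concentrated)}, label: "MAE (spread, concentrated)")
IO.inspect({mse.(spread), mse.(concentrated)}, label: "MSE (spread, concentrated)")
```

Both MAE values are about 0.1, while the MSE of the concentrated case (about 0.04) is roughly four times that of the spread case (about 0.01).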
If you look at our implementations of MAE and MSE, you will notice that they are very similar. MAE and MSE are also called the $L_1$ and $L_2$ losses, after the $L_1$ and $L_2$ norms. The $L_2$ loss (MSE) is often preferred because it is a smooth function, whereas the $L_1$ loss is not differentiable at zero and can be harder to optimize with stochastic gradient descent (SGD).
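The smoothness difference shows up in the per-element gradients with respect to the prediction. For a residual $r = \hat{y} - y$, the squared error has gradient $2r$, which shrinks as the prediction approaches the target, while the absolute error has gradient $\operatorname{sign}(r)$, whose magnitude stays at 1 and flips abruptly at zero. The sketch below uses arbitrary example residuals and returns 0 for the L1 gradient at exactly zero, a common convention rather than a mathematical derivative.

```elixir
# Gradient of r^2 with respect to the prediction.
grad_l2 = fn r -> 2 * r end

# Gradient of |r|: sign(r); undefined at r = 0, where we return 0.0 by convention.
grad_l1 = fn
  r when r > 0 -> 1.0
  r when r < 0 -> -1.0
  _r -> 0.0
end

# Example residuals, two far from the target and two close to it.
residuals = [-0.5, -0.01, 0.01, 0.5]

IO.inspect(Enum.map(residuals, grad_l2), label: "L2 gradients")
IO.inspect(Enum.map(residuals, grad_l1), label: "L1 gradients")
```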
Inference
Now, let's see how our model is doing! We will compare a sample image before and after compression.
sample_image = train_images[0..0//1]
reconstructed_image = Axon.predict(model, params, sample_image, compiler: EXLA)
sample_image
|> Nx.to_heatmap()
|> IO.inspect(label: "Original")
reconstructed_image
|> Nx.to_heatmap()
|> IO.inspect(label: "Reconstructed")
:ok
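The sample is taken with a range, `train_images[0..0//1]`, rather than `train_images[0]`, because the model's input was declared with a leading batch axis ({nil, 1, 28, 28}). The small Nx sketch below uses a stand-in tensor built with `Nx.iota/1` to show the difference in shapes. If you run it inside this notebook, drop the `Mix.install` line, since `Mix.install` cannot be called again with different dependencies in the same session.

```elixir
Mix.install([{:nx, "~> 0.4.0"}])

# Stand-in for train_images: 3 "images", each of shape {1, 2, 2}.
images = Nx.iota({3, 1, 2, 2})

# Integer indexing drops the leading (batch) axis...
single = images[0]
# ...while a range keeps it, giving a batch of one.
batch_of_one = images[0..0//1]

IO.inspect(Nx.shape(single), label: "images[0]")
IO.inspect(Nx.shape(batch_of_one), label: "images[0..0//1]")
```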
As we can see, the reconstructed image is similar to the input image. The most noticeable difference is that the sign in the middle of the shoe is missing from the reconstruction. The model appears to have treated the sign as noise and blended it into the plain shoe.