MNIST Denoising Autoencoder using Kino for visualization

Mix.install([
  {:exla, "~> 0.4.0"},
  {:nx, "~> 0.4.0", override: true},
  {:axon, "~> 0.3.0"},
  {:req, "~> 0.3.1"},
  {:kino, "~> 0.7.0"},
  {:scidata, "~> 0.1.9"},
  {:stb_image, "~> 0.5.2"},
  {:table_rex, "~> 3.1.1"}
])



The goal of this notebook is to build a Denoising Autoencoder from scratch using Livebook. This notebook is based on Training an Autoencoder on Fashion MNIST, but includes some tips on using Livebook to train the model and using Kino (Livebook's interactive widget library) to play with and visualize our results.


Data loading

An autoencoder learns to recreate data it's seen in the dataset. For this notebook, we're going to try something simple: generating images of digits using the MNIST digit recognition dataset.

Following along with the Fashion MNIST Autoencoder example, we'll use Scidata to download the MNIST dataset and then preprocess the data.

# We're not going to use the labels so we'll ignore them
{train_images, _train_labels} = Scidata.MNIST.download()
{train_images_binary, type, shape} = train_images

The shape tells us we have 60,000 images with a single channel of size 28x28.
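As a quick sanity check on that shape, here is a minimal sketch of the arithmetic. It assumes each pixel is stored as one unsigned 8-bit integer, so the raw binary should contain exactly one byte per pixel; the variable names are illustrative only.

```elixir
# Shape reported for the MNIST training images: {images, channels, height, width}.
shape = {60_000, 1, 28, 28}

# Tuple.product/1 multiplies all dimensions together (Elixir 1.12+).
pixel_count = Tuple.product(shape)

# Assumption: one byte per pixel (unsigned 8-bit), so this is also the
# expected size of the raw image binary in bytes.
IO.puts("expected bytes: #{pixel_count}")
```

In the notebook, this number could be compared against `byte_size(train_images_binary)`.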

According to the MNIST website:

Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

Let's preprocess and normalize the data accordingly.
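To make the normalization step concrete, here is a plain-Elixir sketch on a tiny hand-made binary. The four bytes are a hypothetical stand-in for image data; the notebook itself does this with Nx on the full dataset.

```elixir
# A hypothetical 4-pixel "image" stored as raw bytes, like the MNIST binary.
raw = <<0, 51, 102, 255>>

# Read one unsigned byte at a time and scale it into the 0.0..1.0 range.
normalized = for <<pixel::8 <- raw>>, do: pixel / 255

IO.inspect(normalized)
```

Dividing by 255 maps the darkest possible value to 0.0 and the brightest to 1.0, which matches the range of the sigmoid output layer used later.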

train_images =
  train_images_binary
  |> Nx.from_binary(type)
  # Since pixels are organized row-wise, reshape into rows x columns
  |> Nx.reshape(shape, names: [:images, :channels, :height, :width])
  # Normalize the pixel values to be between 0 and 1
  |> Nx.divide(255)
# Make sure they look like numbers
train_images[[images: 0..2]] |> Nx.to_heatmap()

That looks right! Let's repeat the process for the test set.

{test_images, _test_labels} = Scidata.MNIST.download_test()
{test_images_binary, type, shape} = test_images

test_images =
  test_images_binary
  |> Nx.from_binary(type)
  # Since pixels are organized row-wise, reshape into rows x columns
  |> Nx.reshape(shape, names: [:images, :channels, :height, :width])
  # Normalize the pixel values to be between 0 and 1
  |> Nx.divide(255)

test_images[[images: 0..2]] |> Nx.to_heatmap()


Building the model

An autoencoder is a network that has the same sized input as output, with a "bottleneck" layer in the middle that has far fewer units than the input. It is trained so that the output reconstructs the input. The bottleneck layer forces the network to learn a compressed representation of the input space.

A denoising autoencoder is a small tweak on an autoencoder that takes a corrupted input (often corrupted by adding noise or zeroing out pixels) and reconstructs the original input, removing the noise in the process.

The part of the autoencoder that takes the input and compresses it into the bottleneck layer is called the encoder and the part that takes the compressed representation and reconstructs the input is called the decoder. Usually the decoder mirrors the encoder.

MNIST is a pretty easy dataset, so we're going to try a fairly small autoencoder.

The input image has size 784 (28 rows × 28 columns × 1 channel). We'll set up the encoder to turn that into 256 features, then 128, 64, and then 10 features for the bottleneck layer. The decoder will do the reverse: take the 10 features and go to 64, 128, 256 and 784. We'll use fully-connected (dense) layers.


The model

model =
  Axon.input("image", shape: {nil, 1, 28, 28})
  # This is now 28*28*1 = 784
  |> Axon.flatten()
  # The encoder
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(64, activation: :relu)
  # Bottleneck layer
  |> Axon.dense(10, activation: :relu)
  # The decoder
  |> Axon.dense(64, activation: :relu)
  |> Axon.dense(128, activation: :relu)
  |> Axon.dense(256, activation: :relu)
  |> Axon.dense(784, activation: :sigmoid)
  # Turn it back into a 28x28 single channel image
  |> Axon.reshape({:auto, 1, 28, 28})

# We can use Axon.Display to show us what each of the layers would look like
# assuming we send in a batch of 4 images
Axon.Display.as_table(model, Nx.template({4, 1, 28, 28}, :f32)) |> IO.puts()

Checking our understanding: since the layers are all dense layers, each layer should have input_features * output_features parameters for the weights plus output_features parameters for the biases.

The sum over all layers should match the Total Parameters output from Axon.Display (486298 parameters).

# encoder
encoder_parameters = 784 * 256 + 256 + (256 * 128 + 128) + (128 * 64 + 64) + (64 * 10 + 10)
decoder_parameters = 10 * 64 + 64 + (64 * 128 + 128) + (128 * 256 + 256) + (256 * 784 + 784)
total_parameters = encoder_parameters + decoder_parameters
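The same count can be derived from the list of layer sizes alone. This is a minimal sketch; `ParamCount` and `dense_stack/1` are hypothetical names introduced here for illustration, not part of Axon.

```elixir
defmodule ParamCount do
  # Parameters of a stack of dense layers:
  # weights (inputs * outputs) plus biases (outputs) for each consecutive pair.
  def dense_stack(sizes) do
    sizes
    |> Enum.chunk_every(2, 1, :discard)
    |> Enum.map(fn [inputs, outputs] -> inputs * outputs + outputs end)
    |> Enum.sum()
  end
end

# Input size, encoder layers, bottleneck, decoder layers, output size.
layer_sizes = [784, 256, 128, 64, 10, 64, 128, 256, 784]
total = ParamCount.dense_stack(layer_sizes)

IO.puts(total)
# prints 486298
```

Working from a list of sizes makes it easy to see how the total changes if, say, the bottleneck is widened from 10 to 32 features.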



With the model set up, we can now try to train it. We'll use MSE (mean squared error) loss to compare our reconstruction with the original.
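For intuition, mean squared error averages the squared per-pixel differences between the reconstruction and the target. The sketch below works on plain lists; `LossSketch` is a hypothetical module for illustration, and the training loop uses Axon's built-in `:mean_squared_error` instead.

```elixir
defmodule LossSketch do
  # Mean squared error between two equally sized lists of pixel values.
  def mse(predicted, target) do
    squared_errors =
      Enum.zip_with(predicted, target, fn p, t -> (p - t) * (p - t) end)

    Enum.sum(squared_errors) / length(squared_errors)
  end
end

# A perfect reconstruction has zero loss.
perfect = LossSketch.mse([0.0, 1.0, 0.5], [0.0, 1.0, 0.5])

# One of two pixels is off by a full 1.0, so the mean is 0.5.
off = LossSketch.mse([0.0, 1.0], [1.0, 1.0])

IO.inspect({perfect, off})
```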

We'll create the training input by turning our image list into batches of size 128 and then using the same image as both the input and the target. However, the input image will have noise added to it that the autoencoder will have to remove.

For validation data, we'll use the test set and look at how well the autoencoder reconstructs it, to make sure we're not overfitting.

The function below adds noise to an image by adding Gaussian noise, scaled by a noise factor, to it. We then clip the result so the pixel values stay within the 0.0..1.0 range.
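Conceptually, the operation looks like the plain-Elixir sketch below, which works on a list of pixel values. `NoiseSketch` is a hypothetical module for illustration; it uses `:rand.normal/0` from the Erlang standard library rather than Nx, and would be far too slow for real training.

```elixir
defmodule NoiseSketch do
  # Add scaled Gaussian noise to each pixel, then clamp into 0.0..1.0.
  def add_noise(pixels, noise_factor \\ 0.4) do
    Enum.map(pixels, fn pixel ->
      noisy = pixel + noise_factor * :rand.normal()
      noisy |> max(0.0) |> min(1.0)
    end)
  end
end

noisy_pixels = NoiseSketch.add_noise([0.0, 0.25, 0.5, 1.0])
IO.inspect(noisy_pixels)
```

The clamp matters because pixels that start near 0.0 or 1.0 are easily pushed outside the valid range by the added noise.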

We have to define this function using defn so that Nx can optimize it. If we don't do this, adding noise will take a really long time, making our training loop very slow. See Nx.defn for more details. defn can only be used in a module so we'll define a little module to contain it.

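defmodule Noiser do
  import Nx.Defn

  @noise_factor 0.4

  defn add_noise(images) do
    # Scale standard normal noise by the noise factor, add it to the images,
    # and clip the result back into the valid pixel range
    @noise_factor
    |> Nx.multiply(Nx.random_normal(images))
    |> Nx.add(images)
    |> Nx.clip(0.0, 1.0)
  end
end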

add_noise = Nx.Defn.jit(&Noiser.add_noise/1, compiler: EXLA)
batch_size = 128

# The original image, which is the target the network will try to match
batched_train_images =
  train_images
  |> Nx.to_batched(batch_size)

batched_noisy_train_images =
  train_images
  |> Nx.to_batched(batch_size)
  # goes after to_batched so the noise is different every time
  |> Stream.map(add_noise)

# The noisy image is the input to the network
# and the original image is the target it's trying to match
train_data = Stream.zip(batched_noisy_train_images, batched_train_images)

batched_test_images =
  test_images
  |> Nx.to_batched(batch_size)

batched_noisy_test_images =
  test_images
  |> Nx.to_batched(batch_size)
  |> Stream.map(add_noise)

test_data = Stream.zip(batched_noisy_test_images, batched_test_images)
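The remark about the noise being "different every time" relies on stream laziness. Here is a minimal sketch, assuming the noisy batches are produced with a lazy `Stream.map` and paired with their targets via `Stream.zip`; the integers 1..8 are a hypothetical stand-in for images.

```elixir
# Hypothetical stand-in data: the numbers 1..8 play the role of images.
batches = Stream.chunk_every(1..8, 4)

# Stream.map is lazy, so this function runs again on every enumeration.
noisy_batches =
  Stream.map(batches, fn batch ->
    Enum.map(batch, fn value -> value + :rand.normal() end)
  end)

# Pair each noisy input batch with its clean target batch.
pairs = Stream.zip(noisy_batches, batches)

# Enumerating twice re-runs the noise function, so the inputs differ
# while the target stays the same.
{first_input, first_target} = Enum.at(pairs, 0)
{second_input, _same_target} = Enum.at(pairs, 0)

IO.inspect({first_input, second_input, first_target})
```

Because the training loop enumerates the stream once per epoch, each epoch would see freshly corrupted inputs for the same clean targets.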

Let's see what an element of the input and target look like.

{input_batch, target_batch} = Enum.at(train_data, 0)
{Nx.to_heatmap(input_batch[images: 0]), Nx.to_heatmap(target_batch[images: 0])}

Looks right (and tricky). Let's see how the model does.

params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001))
  |> Axon.Loop.validate(model, test_data)
  |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA)


Now that we have a model that theoretically has learned something, we'll see what it's learned by running it on some images from the test set. We'll use Kino to allow us to select the image from the test set to run the model against. To avoid losing the params that took a while to train, we'll create another branch so we can experiment with the params and stop execution when needed without having to retrain.



A note on branching

By default, everything in Livebook runs sequentially in a single process. Stopping a running cell aborts that process and consequently all its state is lost. A branching section copies everything from its parent and runs in a separate process. Thanks to this isolation, when we stop a cell in a branching section, only the state within that section is gone.

Since we just spent a bunch of time training the model and don't want to lose that memory state as we continue to experiment, we create a branching section. This does add some memory overhead, but it's worth it so we can experiment without fear!

To use Kino to give us an interactive tool to evaluate the model, we'll create a Kino.Frame that we can dynamically update. We'll also create a form using Kino.Control to allow the user to select which image from the test set they'd like to evaluate the model on. Finally, Kino.Control.stream enables us to respond to changes in the user's selection when the user clicks the "Render" button.

We can use Nx.concatenate to stack the images side by side for a prettier output.
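Joining along the width axis means each row of the second image is appended to the matching row of the first. A plain-Elixir sketch of the idea, using two hypothetical 2x2 "images" represented as lists of rows:

```elixir
# Two hypothetical 2x2 "images" as lists of rows.
left = [[1, 2], [3, 4]]
right = [[5, 6], [7, 8]]

# Append each row of `right` to the matching row of `left`.
side_by_side = Enum.zip_with(left, right, fn l, r -> l ++ r end)

IO.inspect(side_by_side)
# prints [[1, 2, 5, 6], [3, 4, 7, 8]]
```

The result keeps the same height and doubles the width, which is why the noisy input and its reconstruction appear next to each other.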

form =
  Kino.Control.form(
    [
      test_image_index: Kino.Input.number("Test Image Index", default: 0)
    ],
    submit: "Render"
  )

Kino.render(form)

form
|> Kino.Control.stream()
|> Kino.animate(fn %{data: %{test_image_index: image_index}} ->
  test_image = test_images[[images: image_index]] |> add_noise.()

  reconstructed_image =
    model
    |> Axon.predict(params, test_image)
    # Get rid of the batch dimension
    |> Nx.squeeze(axes: [0])

  combined_image = Nx.concatenate([test_image, reconstructed_image], axis: :width)
  # Render the noisy input and the reconstruction side by side
  Nx.to_heatmap(combined_image)
end)

That looks pretty good!

Note that we used Kino.animate/2, which runs asynchronously, so we don't block execution of the rest of the notebook.


A better training loop

Note that we branch from the "Building the model" section since we only need the model definition for this section and not the previously trained model.

It'd be nice to see how the model improves as it trains. In this section (also a branch since I plan to experiment and don't want to lose the execution state) we'll improve the training loop to use Kino to show us how it's doing.

Axon.Loop.handle gives us a hook into various points of the training loop. We can use it with the :iteration_completed event to get a copy of the state of the params after some number of completed iterations of the training loop. By using those params to render an image in the test set, we can get a live view of the autoencoder learning to reconstruct its inputs.
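To get a feel for how often such a handler would fire, here is some rough arithmetic for a batch size of 128, 20 epochs, and a handler that runs every 450 iterations. It assumes the final, partially filled batch of each epoch still counts as an iteration, and it does not model exactly how Axon counts events, so treat the result as an estimate.

```elixir
train_image_count = 60_000
batch_size = 128
epochs = 20
every = 450

# Assumption: the last partial batch still counts, so round up.
iterations_per_epoch = ceil(train_image_count / batch_size)
total_iterations = iterations_per_epoch * epochs

# Rough number of times a handler firing once per `every` iterations would run.
approx_renders = div(total_iterations, every)

IO.inspect({iterations_per_epoch, total_iterations, approx_renders})
# prints {469, 9380, 20}
```

With about 469 iterations per epoch, firing every 450 iterations works out to roughly one rendered example per epoch.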

# A helper function to display the input and output side by side
combined_input_output = fn params, image_index ->
  test_image = test_images[[images: image_index]] |> add_noise.()
  reconstructed_image = Axon.predict(model, params, test_image) |> Nx.squeeze(axes: [0])
  Nx.concatenate([test_image, reconstructed_image], axis: :width)
end

Nx.to_heatmap(combined_input_output.(params, 0))

It'd also be nice to have a prettier version of the output. Let's convert the heatmap to a png to make that happen.

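image_to_kino = fn image ->
  image
  # Scale 0.0..1.0 floats back to 0..255 bytes
  |> Nx.multiply(255)
  |> Nx.as_type(:u8)
  # StbImage expects height x width x channels
  |> Nx.transpose(axes: [:height, :width, :channels])
  |> StbImage.from_nx()
  |> StbImage.resize(200, 400)
  |> StbImage.to_binary(:png)
  |> Kino.Image.new(:png)
end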

image_to_kino.(combined_input_output.(params, 0))

Much nicer!

Once again we'll use Kino.Frame for dynamically updating output:

frame = Kino.Frame.new() |> Kino.render()

render_example_handler = fn state ->
  Kino.Frame.append(frame, "Epoch: #{state.epoch}, Iteration: #{state.iteration}")
  # state.step_state[:model_state] contains the model params when this event is fired
  params = state.step_state[:model_state]
  image_index = Enum.random(0..(Nx.axis_size(test_images, :images) - 1))
  image = combined_input_output.(params, image_index) |> image_to_kino.()
  Kino.Frame.append(frame, image)
  {:continue, state}
end

params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001))
  |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450)
  |> Axon.Loop.validate(model, test_data)
  |> Axon.Loop.run(train_data, %{}, epochs: 20, compiler: EXLA)


Awesome! We have a working denoising autoencoder that we can visualize getting better in 20 epochs!