A Variational Autoencoder for Fashion-MNIST

Mix.install([
  {:exla, "~> 0.4.0"},
  {:nx, "~> 0.4.0", override: true},
  {:axon, "~> 0.3.0"},
  {:req, "~> 0.3.1"},
  {:kino, "~> 0.7.0"},
  {:scidata, "~> 0.1.9"},
  {:stb_image, "~> 0.5.2"},
  {:kino_vega_lite, "~> 0.1.6"},
  {:vega_lite, "~> 0.1.6"},
  {:table_rex, "~> 3.1.1"}
])

# Short alias so later cells can write `Vl.` instead of `VegaLite.`
alias VegaLite, as: Vl

# This speeds up all our `Nx` operations without having to use `defn`
Nx.global_default_backend(EXLA.Backend)

# End the cell with `:ok` as its value
:ok

Introduction

In this notebook, we'll be building a variational autoencoder (VAE). This will help demonstrate splitting up models, defining custom layers and loss functions, using multiple outputs, and a few additional Kino tricks for training models.

This notebook builds on the denoising autoencoder example and turns the simple autoencoder into a variational one for the same dataset.

Training a simple autoencoder

This section will proceed without much explanation as most of it is extracted from the denoising autoencoder example. If anything here doesn't make sense, take a look at that notebook for an explanation.

defmodule Data do
  @moduledoc """
  A module to hold useful data processing utilities,
  mostly extracted from the previous notebook
  """

  @doc """
  Converts the given image into a `Kino.Image`.

  `image` must be a single channel `Nx` tensor whose axes are named
  `:channels`, `:height` and `:width`. Pixel values are expected to be
  between 0 and 1; anything outside that range is clamped so that it
  cannot overflow the conversion to 8-bit pixels.
  `height` and `width` are the output size in pixels
  """
  def image_to_kino(image, height \\ 200, width \\ 200) do
    image
    |> Nx.multiply(255)
    # Model reconstructions can fall slightly outside 0..1 (the scaling layer
    # stretches the sigmoid output), so clamp before the cast to `:u8`
    |> Nx.clip(0, 255)
    |> Nx.as_type(:u8)
    # Reorder the axes into height x width x channels before handing over to StbImage
    |> Nx.transpose(axes: [:height, :width, :channels])
    |> StbImage.from_nx()
    |> StbImage.resize(height, width)
    |> StbImage.to_binary(:png)
    |> Kino.Image.new(:png)
  end

  @doc """
  Converts image data in the `{{binary, type, shape}, labels}` form returned by
  the `Scidata` MNIST-style datasets into an `Nx` tensor and normalizes it.

  The labels are discarded. The result has the axes `:images`, `:channels`,
  `:height` and `:width`, with pixel values divided by 255.
  """
  def preprocess_data({{images_binary, type, shape}, _labels}) do
    images_binary
    |> Nx.from_binary(type)
    # Since pixels are organized row-wise, reshape into rows x columns
    |> Nx.reshape(shape, names: [:images, :channels, :height, :width])
    # Normalize the pixel values to be between 0 and 1
    |> Nx.divide(255)
  end

  @doc """
  Converts a tensor of images into random batches of paired images for model training

  Every element of the returned stream is `{batch, batch}`: the same batch is
  used as both the model input and the target.
  """
  def prepare_training_data(images, batch_size) do
    # Wrapping the shuffle in a lazy `Stream.flat_map/2` over a one-element list
    # defers it until the stream is enumerated, so each enumeration
    # (presumably once per epoch) sees a freshly shuffled set of batches
    Stream.flat_map([nil], fn nil ->
      images |> Nx.shuffle(axis: :images) |> Nx.to_batched(batch_size)
    end)
    |> Stream.map(fn batch -> {batch, batch} end)
  end
end
# Download both Fashion-MNIST splits and turn them into normalized tensors
train_images = Data.preprocess_data(Scidata.FashionMNIST.download())
test_images = Data.preprocess_data(Scidata.FashionMNIST.download_test())

# Render the first image of each split (train first, then test) as a sanity check
Enum.each([train_images, test_images], fn images ->
  first_image = images[[images: 0]]
  Kino.render(Data.image_to_kino(first_image))
end)

:ok

Now for our simple autoencoder model. We won't be using a denoising autoencoder here.

Note that we're giving each of the layers a name - the reason for this will be apparent later.

I'm also using a small custom layer to shift and scale the output of the sigmoid layer slightly so it can hit the 0 and 1 targets. I noticed the gradients tend to explode without this.

defmodule CustomLayer do
  @moduledoc """
  Custom Axon layers used by the autoencoder models in this notebook.
  """

  import Nx.Defn

  @doc """
  Appends a layer to `input` that computes `(x - 0.05) * 1.2` element-wise.

  `_opts` is accepted but ignored.
  """
  def scaling_layer(%Axon{} = input, _opts \\ []) do
    Axon.layer(&scaling_layer_impl/2, [input])
  end

  # Applied to a sigmoid output (0..1) this stretches the range to roughly
  # -0.06..1.14, so the exact targets 0 and 1 become reachable
  defnp scaling_layer_impl(x, _opts \\ []) do
    (x - 0.05) * 1.2
  end
end
# Encoder half: flatten each 1x28x28 image into 28*28*1 = 784 features and
# squeeze them down, layer by layer, to the 10-feature bottleneck
encoder_stack =
  Axon.input("image", shape: {nil, 1, 28, 28})
  |> Axon.flatten()
  |> Axon.dense(256, activation: :relu, name: "encoder_layer_1")
  |> Axon.dense(128, activation: :relu, name: "encoder_layer_2")
  |> Axon.dense(64, activation: :relu, name: "encoder_layer_3")
  |> Axon.dense(10, activation: :relu, name: "bottleneck_layer")

# Decoder half: mirror the encoder back up to 784 features, rescale the
# sigmoid output and restore the 28x28 single channel image shape
model =
  encoder_stack
  |> Axon.dense(64, activation: :relu, name: "decoder_layer_1")
  |> Axon.dense(128, activation: :relu, name: "decoder_layer_2")
  |> Axon.dense(256, activation: :relu, name: "decoder_layer_3")
  |> Axon.dense(784, activation: :sigmoid, name: "decoder_layer_4")
  |> CustomLayer.scaling_layer()
  |> Axon.reshape({:auto, 1, 28, 28})

# Print a table of the layers as they would look for a batch of 4 images
IO.puts(Axon.Display.as_table(model, Nx.template({4, 1, 28, 28}, :f32)))
# Number of images per batch, shared by the training and validation streams
batch_size = 128

# Both streams yield `{input, target}` pairs where the target is the input itself
train_data = Data.prepare_training_data(train_images, batch_size)
test_data = Data.prepare_training_data(test_images, batch_size)

# Peek at the first training batch to check that input and target really match
{input_batch, target_batch} = Enum.at(train_data, 0)
Kino.render(input_batch[[images: 0]] |> Data.image_to_kino())
Kino.render(target_batch[[images: 0]] |> Data.image_to_kino())

:ok

When training, it can be useful to stop execution early - either when you see it's failing and you don't want to waste time waiting for the remaining epochs to finish, or if it's good enough and you want to start experimenting with it.

The kino_early_stop/1 function below is a handy handler to give us a Kino.Control.button that will stop the training loop when clicked.

We also have a plot_losses/1 function to visualize our train and validation losses using VegaLite.

defmodule KinoAxon do
  @moduledoc """
  Kino widgets that attach to an `Axon.Loop` as event handlers.
  """

  @doc """
  Renders a "stop" button in the cell that runs the training loop and
  registers a handler on `loop` that reacts to it.

  Once the button has been clicked, the next completed iteration halts
  the training loop and the button is replaced by the text "stopped".
  """
  def kino_early_stop(loop) do
    frame = Kino.render(Kino.Frame.new())
    button = Kino.Control.button("stop")
    Kino.Frame.render(frame, button)

    # The agent holds `nil` until the button is clicked, then `:stop`
    {:ok, stop_flag} = Agent.start_link(fn -> nil end)

    button
    |> Kino.Control.stream()
    |> Kino.listen(fn _event -> Agent.update(stop_flag, fn _ -> :stop end) end)

    Axon.Loop.handle(loop, :iteration_completed, fn state ->
      case Agent.get(stop_flag, & &1) do
        :stop ->
          Agent.stop(stop_flag)
          Kino.Frame.render(frame, "stopped")
          {:halt_loop, state}

        _not_requested ->
          {:continue, state}
      end
    end)
  end

  @doc """
  Renders a VegaLite chart and pushes the training and validation loss
  to it every time an epoch completes.

  This *must* come after `Axon.Loop.validate`.
  """
  def plot_losses(loop) do
    chart =
      Vl.new(width: 600, height: 400)
      |> Vl.mark(:point, tooltip: true)
      |> Vl.encode_field(:x, "epoch", type: :ordinal)
      |> Vl.encode_field(:y, "loss", type: :quantitative)
      |> Vl.encode_field(:color, "dataset", type: :nominal)
      |> Kino.VegaLite.new()
      |> Kino.render()

    Axon.Loop.handle(loop, :epoch_completed, fn state ->
      %Axon.Loop.State{metrics: metrics, epoch: epoch} = state

      # One data point per dataset, read from the corresponding metric
      rows =
        for {dataset, metric} <- [{"train", "loss"}, {"validation", "validation_loss"}] do
          %{epoch: epoch, loss: Nx.to_number(metrics[metric]), dataset: dataset}
        end

      Kino.VegaLite.push_many(chart, rows)
      {:continue, state}
    end)
  end
end
# Builds one image with a test image on the left and the model's
# reconstruction of it on the right
combined_input_output = fn params, image_index ->
  original = test_images[[images: image_index]]

  reconstruction =
    model
    |> Axon.predict(params, original)
    |> Nx.squeeze(axes: [0])

  Nx.concatenate([original, reconstruction], axis: :width)
end

frame = Kino.render(Kino.Frame.new())

# Loop handler that shows the reconstruction of a random test image
render_example_handler = fn state ->
  # While this event fires, the model params live under :model_state
  current_params = state.step_state[:model_state]
  last_index = Nx.axis_size(test_images, :images) - 1
  side_by_side = combined_input_output.(current_params, Enum.random(0..last_index))

  Kino.Frame.render(frame, Data.image_to_kino(side_by_side, 200, 400))
  Kino.Frame.append(frame, "Epoch: #{state.epoch}, Iteration: #{state.iteration}")
  {:continue, state}
end

# Train the autoencoder to reproduce its input, using mean squared error.
#
# The optimizer is taken from `Axon.Optimizers`, which is part of the Axon 0.3
# release this notebook installs; Polaris is not among the installed
# dependencies. In that release the learning rate is a positional argument.
#
# Handlers, in order: the early-stop button, an example reconstruction every
# 450 iterations, validation on the test batches, and the loss plot (which
# has to come after `Axon.Loop.validate/2`).
params =
  model
  |> Axon.Loop.trainer(:mean_squared_error, Axon.Optimizers.adamw(0.001))
  |> KinoAxon.kino_early_stop()
  |> Axon.Loop.handle(:iteration_completed, render_example_handler, every: 450)
  |> Axon.Loop.validate(model, test_data)
  |> KinoAxon.plot_losses()
  |> Axon.Loop.run(train_data, %{}, epochs: 40, compiler: EXLA)

:ok

Splitting up the model

Cool! We now have the parameters for a trained, simple autoencoder. Our next step is to split up the model so we can use the encoder and decoder separately. By doing that, we'll be able to take an image and encode it to get the model's compressed image representation (the latent vector). We can then manipulate the latent vector and run the manipulated latent vector through the decoder to get a new image.

Let's start by defining the encoder and decoder separately as two different models.

# The encoder half on its own, with the same layer names as in the autoencoder
encoder =
  Axon.input("image", shape: {nil, 1, 28, 28})
  # Flattening leaves 28*28*1 = 784 features per image
  |> Axon.flatten()
  |> Axon.dense(256, activation: :relu, name: "encoder_layer_1")
  |> Axon.dense(128, activation: :relu, name: "encoder_layer_2")
  |> Axon.dense(64, activation: :relu, name: "encoder_layer_3")
  # The bottleneck produces the 10-feature latent vector
  |> Axon.dense(10, activation: :relu, name: "bottleneck_layer")

# The decoder half on its own; its input is what the encoder outputs
decoder =
  Axon.input("latent", shape: {nil, 10})
  |> Axon.dense(64, activation: :relu, name: "decoder_layer_1")
  |> Axon.dense(128, activation: :relu, name: "decoder_layer_2")
  |> Axon.dense(256, activation: :relu, name: "decoder_layer_3")
  |> Axon.dense(784, activation: :sigmoid, name: "decoder_layer_4")
  |> CustomLayer.scaling_layer()
  # Back to a 28x28 single channel image
  |> Axon.reshape({:auto, 1, 28, 28})

# Print the layer table of each half for a batch of 4 inputs
Enum.each([{encoder, {4, 1, 28, 28}}, {decoder, {4, 10}}], fn {network, input_shape} ->
  IO.puts(Axon.Display.as_table(network, Nx.template(input_shape, :f32)))
end)

We have the two models, but the problem is these are untrained models so we don't have the corresponding set of parameters. We'd like to use the parameters from the autoencoder we just trained and apply them to our split up models.

Let's first take a look at what params actually are:

# Evaluate the bare variable so the trained parameters are shown as the cell's result
params

Params are just a Map with the layer name as the key identifying which parameters to use. We can easily match up the layer names with the output from the Axon.Display.as_table/2 call for the autoencoder model.

So all we need to do is create a new Map that plucks out the right layers from our autoencoder params for each model and use that to run inference on our split up models.

Fortunately, since we gave each of the layers names, this requires no work at all - we can use the Map as it is since the layer names match up! Axon will ignore any extra keys so those won't be a problem.

Note that naming the layers wasn't required; if the layers didn't have names, we would have some renaming to do to get the names to match between the models. But giving them names made it very convenient :)

Let's try encoding an image, printing the latent and then decoding the latent using our split up model to make sure it's working.

# Use the first test image for the round trip
image = test_images[[images: 0]]

# Image -> latent vector, printed so we can look at it
latent =
  encoder
  |> Axon.predict(params, image)
  |> IO.inspect(label: "Latent")

# Latent vector -> image, dropping the leading batch axis
reconstructed_image =
  decoder
  |> Axon.predict(params, latent)
  |> Nx.squeeze(axes: [0])

# Original on the left, reconstruction on the right
combined_image = Nx.concatenate([image, reconstructed_image], axis: :width)
Data.image_to_kino(combined_image, 200, 400)

Perfect! Seems like the split up models are working as expected. Now let's try to generate some new images using our autoencoder. To do this, we'll manipulate the latent so it's slightly different from what the encoder gave us. Specifically, we'll try to interpolate between two images, showing 100 steps from our starting image to our final image.

num_steps = 100

# Encode test images 0 (where we start) and 1 (where we end) in one batch,
# which gives a {2, 10} tensor of latents
latents = Axon.predict(encoder, params, test_images[[images: 0..1]])

# The increment that moves the starting latent one step towards the final one
step =
  latents[1]
  |> Nx.subtract(latents[0])
  |> Nx.divide(num_steps)

# Row i of this batch is `start + i * step`, for i from 0 to num_steps
new_latents =
  {num_steps + 1, 1}
  |> Nx.iota()
  |> Nx.multiply(step)
  |> Nx.add(latents[0])

# Decode every interpolated latent and name the axes so that
# `Data.image_to_kino/3` can transpose the images
reconstructed_images =
  decoder
  |> Axon.predict(params, new_latents)
  |> then(&Nx.reshape(&1, Nx.shape(&1), names: [:images, :channels, :height, :width]))

# Play all frames back over roughly five seconds
frame_interval = div(5000, num_steps)

frame_interval
|> Stream.interval()
|> Stream.take(num_steps + 1)
|> Kino.animate(&Data.image_to_kino(reconstructed_images[&1]))

Cool! We have interpolation! But did you notice that some of the intermediate frames don't look fashionable at all? Autoencoders don't generally return good results for random vectors in their latent space. That's where a VAE can help.

Making it variational

In a VAE, instead of outputting a latent vector, our encoder will output a distribution. Essentially this means instead of 10 outputs we'll have 20. 10 of them will represent the mean and 10 will represent the log of the variance of the latent. We'll have to sample from this distribution to get our latent vector. Finally, we'll have to modify our loss function to also compute the KL Divergence between the latent distribution and a standard normal distribution (this acts as a regularizer of the latent space).

We'll start by defining our model:

defmodule Vae do
  @moduledoc """
  The variational autoencoder, split into an encoder, a decoder and the
  combined model used for training.
  """

  import Nx.Defn

  @latent_features 10

  # Wraps `sampling_layer_impl/2` in a custom Axon layer named "sampling_layer"
  defp sampling_layer(%Axon{} = input, _opts \\ []) do
    Axon.layer(&sampling_layer_impl/2, [input], name: "sampling_layer", op_name: :sample)
  end

  # `x` carries, along axis 1, the mean (index 0) and the log of the variance
  # (index 1) of the latent distribution. Draws `mu + std_dev * eps` with `eps`
  # standard normal and returns sample, mu and std_dev stacked along axis 1.
  defnp sampling_layer_impl(x, _opts \\ []) do
    mu = x[[0..-1//1, 0, 0..-1//1]]
    log_var = x[[0..-1//1, 1, 0..-1//1]]
    std_dev = Nx.exp(0.5 * log_var)
    eps = Nx.random_normal(std_dev)
    sample = mu + std_dev * eps
    Nx.stack([sample, mu, std_dev], axis: 1)
  end

  # The encoder up to and including the sampling layer, i.e. its output still
  # contains the sample together with mu and std_dev
  defp encoder_partial() do
    Axon.input("image", shape: {nil, 1, 28, 28})
    # This is now 28*28*1 = 784
    |> Axon.flatten()
    # The encoder
    |> Axon.dense(256, activation: :relu, name: "encoder_layer_1")
    |> Axon.dense(128, activation: :relu, name: "encoder_layer_2")
    |> Axon.dense(64, activation: :relu, name: "encoder_layer_3")
    # Bottleneck layer
    |> Axon.dense(@latent_features * 2, name: "bottleneck_layer")
    # Split up the mu and logvar
    |> Axon.reshape({:auto, 2, @latent_features})
    |> sampling_layer()
  end

  # Grab only the sample (ie. the sampled latent) from the sampling layer output
  defp take_sample(%Axon{} = sampled) do
    Axon.nx(sampled, fn x -> x[[0..-1//1, 0]] end)
  end

  @doc """
  Returns the encoder model: image in, sampled latent vector out.
  """
  def encoder() do
    take_sample(encoder_partial())
  end

  @doc """
  Appends the decoder layers to `input_latent` and returns the resulting
  model, whose output is reshaped into 28x28 single channel images.
  """
  def decoder(input_latent) do
    input_latent
    |> Axon.dense(64, activation: :relu, name: "decoder_layer_1")
    |> Axon.dense(128, activation: :relu, name: "decoder_layer_2")
    |> Axon.dense(256, activation: :relu, name: "decoder_layer_3")
    |> Axon.dense(784, activation: :sigmoid, name: "decoder_layer_4")
    |> CustomLayer.scaling_layer()
    # Turn it back into a 28x28 single channel image
    |> Axon.reshape({:auto, 1, 28, 28})
  end

  @doc """
  Returns the full model for training.

  Its output is a map with two entries: `:mu_sigma`, the output of the
  sampling layer (sample, mu and std_dev), and `:reconstruction`, the
  decoded image.
  """
  def autoencoder() do
    # Build the encoder graph once and feed both outputs from it, so that the
    # distribution parameters reported in `:mu_sigma` belong to the very sample
    # that is decoded, and the encoder is not evaluated twice per pass
    mu_sigma = encoder_partial()
    reconstruction = mu_sigma |> take_sample() |> decoder()
    Axon.container(%{mu_sigma: mu_sigma, reconstruction: reconstruction})
  end
end

There are a few interesting things going on here. First, since our model has become more complex, we've used a module to keep it organized. We also built a custom layer to do the sampling and output the sampled latent vector as well as the distribution parameters (mu and sigma).

Finally, we need the distribution itself so we can calculate the KL Divergence in our loss function. To make the model output the distribution parameters (mu and sigma), we use Axon.container/1 to produce two outputs from our model instead of one. Now, instead of getting a tensor as an output, we'll get a map with the two tensors we need for our loss function.

Our loss function also has to be modified to be the sum of the KL divergence and MSE. Here's our custom loss function:

defmodule CustomLoss do
  @moduledoc """
  The training loss for the variational autoencoder.
  """

  import Nx.Defn

  @doc """
  Returns the weighted KL divergence plus the summed squared reconstruction error.

  `y_true` is the target image batch. The second argument is the model output
  map: `:reconstruction` holds the decoded images and `:mu_sigma` holds, along
  axis 1, the sample (index 0), mu (index 1) and sigma (index 2).

  The KL divergence between `N(mu, sigma^2)` and the standard normal is
  `-log(sigma) - 0.5 + 0.5 * sigma^2 + 0.5 * mu^2` per latent feature, which is
  zero exactly when `mu = 0` and `sigma = 1`. It is summed over the batch and
  all latent features.
  """
  defn loss(y_true, %{reconstruction: reconstruction, mu_sigma: mu_sigma}) do
    mu = mu_sigma[[0..-1//1, 1, 0..-1//1]]
    sigma = mu_sigma[[0..-1//1, 2, 0..-1//1]]
    kld = Nx.sum(-Nx.log(sigma) - 0.5 + 0.5 * sigma * sigma + 0.5 * mu * mu)
    # The KL term is scaled by 0.1 relative to the reconstruction error
    kld * 0.1 + Axon.Losses.mean_squared_error(y_true, reconstruction, reduction: :sum)
  end
end

With all our pieces ready, we can pretty much use the same training loop as we did earlier. The only modifications needed are to account for the fact that the model outputs a map with two values instead of a single tensor and telling the trainer to use our custom loss.

model = Vae.autoencoder()

# Builds one image with a test image on the left and the VAE's
# reconstruction of it on the right
combined_input_output = fn params, image_index ->
  original = test_images[[images: image_index]]
  # The model returns a map; only the reconstruction is needed here
  %{reconstruction: reconstruction_batch} = Axon.predict(model, params, original)
  reconstruction = Nx.squeeze(reconstruction_batch, axes: [0])
  Nx.concatenate([original, reconstruction], axis: :width)
end

frame = Kino.render(Kino.Frame.new())

# Loop handler that shows the reconstruction of a random test image
render_example_handler = fn state ->
  # While this event fires, the model params live under :model_state
  current_params = state.step_state[:model_state]
  last_index = Nx.axis_size(test_images, :images) - 1
  side_by_side = combined_input_output.(current_params, Enum.random(0..last_index))

  Kino.Frame.render(frame, Data.image_to_kino(side_by_side, 200, 400))
  Kino.Frame.append(frame, "Epoch: #{state.epoch}, Iteration: #{state.iteration}")
  {:continue, state}
end

# Train the VAE with the custom loss (KL divergence plus reconstruction error).
#
# The optimizer is taken from `Axon.Optimizers`, which is part of the Axon 0.3
# release this notebook installs; Polaris is not among the installed
# dependencies. In that release the learning rate is a positional argument.
#
# Handlers, in order: the early-stop button, an example reconstruction after
# every epoch, validation on the test batches, and the loss plot (which has
# to come after `Axon.Loop.validate/2`).
params =
  model
  |> Axon.Loop.trainer(&CustomLoss.loss/2, Axon.Optimizers.adam(0.001))
  |> KinoAxon.kino_early_stop()
  |> Axon.Loop.handle(:epoch_completed, render_example_handler)
  |> Axon.Loop.validate(model, test_data)
  |> KinoAxon.plot_losses()
  |> Axon.Loop.run(train_data, %{}, epochs: 40, compiler: EXLA)

:ok

Finally, we can try our interpolation again:

num_steps = 100

# Encode test images 0 (where we start) and 1 (where we end) in one batch,
# which gives a {2, 10} tensor of sampled latents
latents = Axon.predict(Vae.encoder(), params, test_images[[images: 0..1]])

# The increment that moves the starting latent one step towards the final one
step =
  latents[1]
  |> Nx.subtract(latents[0])
  |> Nx.divide(num_steps)

# Row i of this batch is `start + i * step`, for i from 0 to num_steps
new_latents =
  {num_steps + 1, 1}
  |> Nx.iota()
  |> Nx.multiply(step)
  |> Nx.add(latents[0])

# A standalone decoder that takes a latent vector as its input
decoder = Vae.decoder(Axon.input("latent", shape: {nil, 10}))

# Decode every interpolated latent and name the axes so that
# `Data.image_to_kino/3` can transpose the images
reconstructed_images =
  decoder
  |> Axon.predict(params, new_latents)
  |> then(&Nx.reshape(&1, Nx.shape(&1), names: [:images, :channels, :height, :width]))

# Play all frames back over roughly five seconds
frame_interval = div(5000, num_steps)

frame_interval
|> Stream.interval()
|> Stream.take(num_steps + 1)
|> Kino.animate(&Data.image_to_kino(reconstructed_images[&1]))

Did you notice the difference? Every step in our interpolation looks similar to items in our dataset! This is the benefit of the VAE: we can generate new items by using random latents. In contrast, in the simple autoencoder, for the most part only latents we got from our encoder were likely to produce sensible outputs.