Custom layers

Mix.install([
  {:axon, ">= 0.5.0"},
  {:kino, ">= 0.9.0"}
])
:ok

Creating custom layers

Axon ships with many built-in layers, but you will likely run into a case that needs something the framework does not provide. In these cases, you can write a custom layer.

To Axon, layers are really just defn implementations with special Axon inputs. Every layer in Axon, including the built-in layers, is implemented with the Axon.layer/3 function. The API of Axon.layer/3 intentionally mirrors the API of Kernel.apply/2: you pass a function and a list of inputs. To declare a custom layer you need two things:

  1. A defn implementation
  2. Inputs
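
The comparison with Kernel.apply/2 can be sketched as follows. This is a minimal illustration, assuming the Mix.install setup above; it uses plain anonymous functions instead of defn to keep it short, and the names scale, layer_fun, and x are hypothetical:

```elixir
# Kernel.apply/2 takes a function and a list of arguments and calls it right away.
scale = fn x, opts -> x * Keyword.get(opts, :alpha, 1.0) end
result = apply(scale, [3.0, [alpha: 2.0]])

# Axon.layer/3 takes a function and a list of Axon inputs, but only records the
# call in the model graph. The function runs later, when the model is executed,
# and receives the layer options as its last argument.
layer_fun = fn x, opts -> Nx.multiply(x, opts[:alpha]) end
model = Axon.layer(layer_fun, [Axon.input("data")], alpha: 2.0)

x = Nx.tensor([[3.0]])
{init_fn, predict_fn} = Axon.build(model)
pred = predict_fn.(init_fn.(x, %{}), x)
```

Both calls should scale 3.0 by 2.0; the difference is that the Axon version only does so once the model is built and run.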

The defn implementation looks like any other defn you'd write; however, it must always account for additional opts as an argument:

defmodule CustomLayers0 do
  import Nx.Defn

  defn my_layer(input, opts \\ []) do
    opts = keyword!(opts, mode: :train, alpha: 1.0)

    input
    |> Nx.sin()
    |> Nx.multiply(opts[:alpha])
  end
end
{:module, CustomLayers0, <<70, 79, 82, 49, 0, 0, 10, ...>>, true}

Whatever options you configure your layer to accept, the defn implementation always receives a :mode option indicating whether the model is running in training mode (:train) or inference mode (:inference). You can customize the behavior of your layer depending on the mode.
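
As a rough sketch of mode-dependent behavior, the hypothetical layer below doubles its input in training mode and passes it through unchanged in inference mode. The module and function names are made up for illustration, and the branch on the mode atom is placed in a deftransformp helper because it is ordinary Elixir control flow rather than a tensor operation:

```elixir
defmodule ModeAwareLayer do
  import Nx.Defn

  # Hypothetical layer: behaves differently depending on opts[:mode].
  defn scale_in_train(input, opts \\ []) do
    opts = keyword!(opts, mode: :inference)
    apply_mode(input, opts[:mode])
  end

  # The mode is a plain atom, so the branch is resolved while the model is compiled.
  deftransformp apply_mode(input, mode) do
    case mode do
      :train -> Nx.multiply(input, 2.0)
      :inference -> input
    end
  end
end

model = Axon.layer(&ModeAwareLayer.scale_in_train/2, [Axon.input("data")])
x = Nx.tensor([[1.0, 2.0]])

# Inference mode: the prediction function returns the output tensor.
{init_fn, predict_fn} = Axon.build(model, mode: :inference)
inference_out = predict_fn.(init_fn.(x, %{}), x)

# Training mode: the prediction function returns a map with :prediction and :state.
{init_fn, train_fn} = Axon.build(model, mode: :train)
%{prediction: train_out} = train_fn.(init_fn.(x, %{}), x)
```

Building the same model twice with different :mode values is enough to exercise both branches.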

With an implementation defined, you only need to call Axon.layer/3 to apply your custom layer to an Axon input:

input = Axon.input("data")

out = Axon.layer(&CustomLayers0.my_layer/2, [input])
#Axon<
  inputs: %{"data" => nil}
  outputs: "custom_0"
  nodes: 2
>

Now you can inspect and execute your model as normal:

template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
4["custom_0 (:custom) {2, 8}"];
3 --> 4;

Notice that custom layers render with the operation name :custom by default. This can make it difficult to tell layers apart during inspection. You can control the rendering by passing :op_name to Axon.layer/3:

out = Axon.layer(&CustomLayers0.my_layer/2, [input], op_name: :my_layer)

Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
5["my_layer_0 (:my_layer) {2, 8}"];
3 --> 5;

You can also control the name of your layer via the :name option. All other options are forwarded to the layer implementation function:

out =
  Axon.layer(&CustomLayers0.my_layer/2, [input],
    name: "layer",
    op_name: :my_layer,
    alpha: 2.0
  )

Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
6["layer (:my_layer) {2, 8}"];
3 --> 6;

Building and running the model confirms that alpha: 2.0 reached the implementation, since every output is 2.0 * sin(x):

{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
%{}
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
#Nx.Tensor<
  f32[2][8]
  [
    [0.0, 1.6829419136047363, 1.8185948133468628, 0.28224000334739685, -1.513604998588562, -1.9178485870361328, -0.558830976486206, 1.3139731884002686],
    [1.978716492652893, 0.8242369890213013, -1.0880422592163086, -1.9999804496765137, -1.073145866394043, 0.8403340578079224, 1.9812147617340088, 1.3005757331848145]
  ]
>

Notice that init_fn returned an empty map: the model has no trainable parameters because none of its layers declare any. You can introduce trainable parameters by passing inputs created with Axon.param/3 to Axon.layer/3. For example, you can modify your original custom layer to take an additional trainable parameter:

defmodule CustomLayers1 do
  import Nx.Defn

  defn my_layer(input, alpha, _opts \\ []) do
    input
    |> Nx.sin()
    |> Nx.multiply(alpha)
  end
end
{:module, CustomLayers1, <<70, 79, 82, 49, 0, 0, 10, ...>>, true}

And then construct the layer with a regular Axon input and a trainable parameter:

alpha = Axon.param("alpha", fn _ -> {} end)

out = Axon.layer(&CustomLayers1.my_layer/3, [input, alpha], op_name: :my_layer)
#Axon<
  inputs: %{"data" => nil}
  outputs: "my_layer_0"
  nodes: 2
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
%{
  "my_layer_0" => %{
    "alpha" => #Nx.Tensor<
      f32
      -1.2601861953735352
    >
  }
}

Notice how your model now initializes with a trainable parameter "alpha" for your custom layer. Each parameter requires a unique (per-layer) string name and a function which determines the parameter's shape from the layer's input shapes.
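
The shape function does not have to ignore its argument. The sketch below, which assumes the shape function receives the full input shape including the batch dimension, declares one scale value per feature by reading the last dimension. The module, layer, and parameter names are hypothetical, and initializer: :ones is chosen only so the result is easy to check:

```elixir
defmodule FeatureScale do
  import Nx.Defn

  # Hypothetical layer with one trainable scale per feature (last axis).
  def feature_scale(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name])

    # For an input of shape {2, 8} this returns {8}.
    shape_fn = fn input_shape -> {elem(input_shape, tuple_size(input_shape) - 1)} end
    scale = Axon.param("scale", shape_fn, initializer: :ones)

    Axon.layer(&feature_scale_impl/3, [input, scale],
      name: opts[:name],
      op_name: :feature_scale
    )
  end

  # The {8} parameter broadcasts against the {2, 8} input.
  defnp feature_scale_impl(input, scale, _opts \\ []) do
    Nx.multiply(input, scale)
  end
end

model = Axon.input("data") |> FeatureScale.feature_scale(name: "scale_features")

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({2, 8}, :f32), %{})

x = Nx.iota({2, 8}, type: :f32)
y = predict_fn.(params, x)
```

Because the scale starts at all ones, the freshly initialized layer should return its input unchanged.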

If you plan on re-using custom layers in many locations, it's recommended that you wrap them in an Elixir function as an interface:

defmodule CustomLayers2 do
  import Nx.Defn

  def my_layer(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name])
    alpha = Axon.param("alpha", fn _ -> {} end)

    Axon.layer(&my_layer_impl/3, [input, alpha], name: opts[:name], op_name: :my_layer)
  end

  defnp my_layer_impl(input, alpha, _opts \\ []) do
    input
    |> Nx.sin()
    |> Nx.multiply(alpha)
  end
end
{:module, CustomLayers2, <<70, 79, 82, 49, 0, 0, 12, ...>>, true}
out =
  input
  |> CustomLayers2.my_layer()
  |> CustomLayers2.my_layer()
  |> Axon.dense(1)
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_0"
  nodes: 4
>
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
8["my_layer_0 (:my_layer) {2, 8}"];
9["my_layer_1 (:my_layer) {2, 8}"];
10["dense_0 (:dense) {2, 1}"];
9 --> 10;
8 --> 9;
3 --> 8;
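
A model built from wrapped layers is initialized and run like any other. The sketch below uses a hypothetical ReusableLayer module that mirrors CustomLayers2, so that it can be evaluated on its own; the layer names "first" and "second" are illustrative:

```elixir
defmodule ReusableLayer do
  import Nx.Defn

  # Hypothetical stand-in for CustomLayers2: sin(input) scaled by a trainable alpha.
  def scaled_sin(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name])
    alpha = Axon.param("alpha", fn _ -> {} end)

    Axon.layer(&scaled_sin_impl/3, [input, alpha], name: opts[:name], op_name: :scaled_sin)
  end

  defnp scaled_sin_impl(input, alpha, _opts \\ []) do
    input
    |> Nx.sin()
    |> Nx.multiply(alpha)
  end
end

model =
  Axon.input("data")
  |> ReusableLayer.scaled_sin(name: "first")
  |> ReusableLayer.scaled_sin(name: "second")
  |> Axon.dense(1)

template = Nx.template({2, 8}, :f32)
{init_fn, predict_fn} = Axon.build(model)

# Each use of the wrapper declares its own "alpha", stored under that layer's name.
params = init_fn.(template, %{})
prediction = predict_fn.(params, Nx.iota({2, 8}, type: :f32))
```

Inspecting params should show a separate "alpha" entry for each of the two custom layers alongside the dense layer's parameters, and the prediction should have shape {2, 1}.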