Custom layers
Mix.install([
{:axon, github: "elixir-nx/axon"},
{:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true},
{:kino, "~> 0.7.0"}
])
:ok
Creating custom layers
Axon ships with many built-in layers, but you will likely run into a case where you need something the framework does not provide. In those cases, you can write a custom layer.
To Axon, layers are really just defn implementations with special Axon inputs. Every layer in Axon, including the built-in layers, is implemented with the Axon.layer/3 function. The API of Axon.layer/3 intentionally mirrors the API of Kernel.apply/2. To declare a custom layer you need two things:
- A defn implementation
- Inputs
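The comparison with Kernel.apply/2 can be made concrete with a small plain-Elixir sketch. The `scale` function is a hypothetical stand-in and has nothing to do with Axon; the point is only the calling shape of "a function plus a list of its inputs".

```elixir
# Kernel.apply/2 takes a function and a list of arguments for it.
# `scale` is a made-up function used purely for illustration.
scale = fn x, factor -> x * factor end

result = apply(scale, [3, 2])
# => 6
```

Axon.layer/3 follows the same shape: a function first, then a list of inputs, with any options passed last.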
The defn implementation looks like any other defn you'd write; however, it must always account for additional opts as an argument:
defmodule CustomLayers do
  import Nx.Defn
  defn my_layer(input, opts \\ []) do
    opts = keyword!(opts, mode: :train, alpha: 1.0)
    input
    |> Nx.sin()
    |> Nx.multiply(opts[:alpha])
  end
end
{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:my_layer, 2}}
Regardless of the options you configure your layer to accept, the defn implementation will always receive a :mode option indicating whether the model is running in training or inference mode. You can customize the behavior of your layer depending on the mode.
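As a rough sketch of mode-dependent behavior, the hypothetical layer below scales its output only in training mode. It is written with plain def rather than defn so that an ordinary Elixir case can branch on the :mode atom. That a plain function is acceptable as a layer implementation is an assumption here, since this guide only shows defn. The module name, function name, and the choice to scale only during training are all illustrative. The block assumes Nx is available from the setup cell at the top.

```elixir
defmodule ModeAwareLayers do
  # Hypothetical layer: plain `def` so a regular `case` can branch on :mode.
  def scaled_sin(input, opts \\ []) do
    opts = Keyword.validate!(opts, mode: :inference, alpha: 1.0)
    activated = Nx.sin(input)

    case opts[:mode] do
      # Scaling only while training is an arbitrary choice for illustration.
      :train -> Nx.multiply(activated, opts[:alpha])
      :inference -> activated
    end
  end
end

# Calling the function directly, passing the mode by hand.
x = Nx.tensor([0.5, 1.0])
train_out = ModeAwareLayers.scaled_sin(x, mode: :train, alpha: 2.0)
inference_out = ModeAwareLayers.scaled_sin(x, mode: :inference, alpha: 2.0)
```

Here train_out is twice inference_out. When the function is used as a layer, Axon supplies :mode itself, as described above.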
With an implementation defined, you need only call Axon.layer/3 to apply your custom layer to an Axon input:
input = Axon.input("data")
out = Axon.layer(&CustomLayers.my_layer/2, [input])
#Axon<
inputs: %{"data" => nil}
outputs: "custom_0"
nodes: 2
>
Now you can inspect and execute your model as normal:
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
4["custom_0 (:custom) {2, 8}"];
3 --> 4;
Notice that custom layers render with the default operation name :custom. This can make it difficult to tell layers apart during inspection. You can control the rendering by passing :op_name to Axon.layer/3:
out = Axon.layer(&CustomLayers.my_layer/2, [input], op_name: :my_layer)
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
5["my_layer_0 (:my_layer) {2, 8}"];
3 --> 5;
You can also control the name of your layer via the :name option. All other options are forwarded to the layer implementation function:
out =
  Axon.layer(&CustomLayers.my_layer/2, [input],
    name: "layer",
    op_name: :my_layer,
    alpha: 2.0
  )
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
6["layer (:my_layer) {2, 8}"];
3 --> 6;
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
%{}
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
#Nx.Tensor<
f32[2][8]
[
[0.0, 1.6829419136047363, 1.8185948133468628, 0.28224000334739685, -1.513604998588562, -1.9178485870361328, -0.558830976486206, 1.3139731884002686],
[1.978716492652893, 0.8242369890213013, -1.0880422592163086, -1.9999804496765137, -1.073145866394043, 0.8403340578079224, 1.9812147617340088, 1.3005757331848145]
]
>
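One way to confirm that alpha: 2.0 really was forwarded to the implementation is to repeat the computation directly in Nx and compare. This is only a sanity-check sketch; it assumes the predict_fn and params bindings from the cells above are still in scope.

```elixir
x = Nx.iota({2, 8}, type: :f32)

# The same computation the custom layer performs with alpha: 2.0, written directly in Nx.
expected = x |> Nx.sin() |> Nx.multiply(2.0)

# A scalar 1 when every element matches within Nx.all_close/3's default tolerances.
matches = Nx.all_close(predict_fn.(params, x), expected)
```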
Notice that this model does not have any trainable parameters because none of the layers have trainable parameters. You can introduce trainable parameters by passing inputs created with Axon.param/3 to Axon.layer/3. For example, you can modify your original custom layer to take an additional trainable parameter:
defmodule CustomLayers do
  import Nx.Defn
  defn my_layer(input, alpha, _opts \\ []) do
    input
    |> Nx.sin()
    |> Nx.multiply(alpha)
  end
end
{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 11, ...>>, {:my_layer, 3}}
And then construct the layer with a regular Axon input and a trainable parameter:
alpha = Axon.param("alpha", fn _ -> {} end)
out = Axon.layer(&CustomLayers.my_layer/3, [input, alpha], op_name: :my_layer)
#Axon<
inputs: %{"data" => nil}
outputs: "my_layer_0"
nodes: 2
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
%{
"my_layer_0" => %{
"alpha" => #Nx.Tensor<
f32
1.194254994392395
>
}
}
Notice how your model now initializes with a trainable parameter "alpha" for your custom layer. Each parameter requires a unique (per-layer) string name and a function which determines the parameter's shape from the layer's input shapes.
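The shape function above ignores its argument and always returns a scalar shape. As a sketch of a shape that actually depends on the input, the hypothetical per_feature function below returns one alpha per feature, that is, the size of the last axis. The names per_feature, per_feature_alpha, and per_feature_out are made up for illustration. The block assumes the setup cell and the three-argument CustomLayers.my_layer defined above, and it passes initializer: :ones only to keep the values predictable.

```elixir
# Hypothetical shape function: one alpha per feature, i.e. the size of the
# last axis of the input shape.
per_feature = fn input_shape ->
  {elem(input_shape, tuple_size(input_shape) - 1)}
end

per_feature.({2, 8})
# => {8}

# Assumes Axon from the setup cell and CustomLayers.my_layer/3 from above.
per_feature_alpha = Axon.param("alpha", per_feature, initializer: :ones)

per_feature_out =
  Axon.layer(&CustomLayers.my_layer/3, [Axon.input("data"), per_feature_alpha],
    op_name: :my_layer
  )
```

With an {8} parameter and a {2, 8} input, Nx.multiply/2 broadcasts alpha across the batch axis.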
If you plan on re-using custom layers in many locations, it's recommended that you wrap them in an Elixir function as an interface:
defmodule CustomLayers do
  import Nx.Defn
  def my_layer(%Axon{} = input, opts \\ []) do
    opts = Keyword.validate!(opts, [:name])
    alpha = Axon.param("alpha", fn _ -> {} end)
    Axon.layer(&my_layer_impl/3, [input, alpha], name: opts[:name], op_name: :my_layer)
  end
  defnp my_layer_impl(input, alpha, _opts \\ []) do
    input
    |> Nx.sin()
    |> Nx.multiply(alpha)
  end
end
{:module, CustomLayers, <<70, 79, 82, 49, 0, 0, 13, ...>>, {:my_layer_impl, 3}}
out =
  input
  |> CustomLayers.my_layer()
  |> CustomLayers.my_layer()
  |> Axon.dense(1)
#Axon<
inputs: %{"data" => nil}
outputs: "dense_0"
nodes: 4
>
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
10["my_layer_0 (:my_layer) {2, 8}"];
12["my_layer_1 (:my_layer) {2, 8}"];
15["dense_0 (:dense) {2, 1}"];
12 --> 15;
10 --> 12;
3 --> 10;
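The composed model can be built and run the same way as the earlier ones. The sketch below assumes the out and template bindings from the cells above. Since each call to CustomLayers.my_layer/2 creates its own "alpha" parameter, the initialized parameters should contain separate entries for "my_layer_0" and "my_layer_1" alongside the dense layer's parameters.

```elixir
# Assumes `out` and `template` from the cells above.
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})

# The dense layer has one unit, so the prediction has shape {2, 1}.
prediction = predict_fn.(params, Nx.iota({2, 8}, type: :f32))
Nx.shape(prediction)
```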