Complex models
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:kino, "~> 0.7.0"}
])
:ok
Creating more complex models
Not all models you'd want to create fit cleanly in the sequential paradigm. Some models require a more flexible API. Fortunately, because Axon models are just Elixir data structures, you can manipulate them and decompose architectures as you would any other Elixir program:
input = Axon.input("data")
x1 = input |> Axon.dense(32)
x2 = input |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)
out = Axon.add(x1, x2)
#Axon<
inputs: %{"data" => nil}
outputs: "add_0"
nodes: 7
>
In the snippet above, your model branches input into x1 and x2. Each branch performs a different set of transformations; however, at the end the branches are merged with Axon.add/3. You might sometimes see layers like Axon.add/3 called combinators. Really they're just layers that operate on multiple Axon models at once - typically to merge some branches together.
The variable out represents your final Axon model.
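Axon.add/3 is not the only combinator. As a minimal sketch, the branches can instead be joined with Axon.concatenate/3, which stacks their features along the last axis by default. The names branch_a, branch_b, merged, and the concat_ prefixed variables are illustrative only, and the snippet assumes the Mix.install/2 setup at the top of this guide has already run:

```elixir
# Illustrative names; not part of the original guide.
concat_input = Axon.input("data")

# Two branches with different output widths
branch_a = concat_input |> Axon.dense(16)
branch_b = concat_input |> Axon.dense(8) |> Axon.relu()

# Unlike Axon.add/3, concatenation does not require matching feature sizes
merged = Axon.concatenate(branch_a, branch_b)

{concat_init_fn, concat_predict_fn} = Axon.build(merged)
concat_params = concat_init_fn.(Nx.template({2, 8}, :f32), %{})
concat_result = concat_predict_fn.(concat_params, Nx.iota({2, 8}, type: :f32))

# 16 features from branch_a plus 8 from branch_b
Nx.shape(concat_result)
# => {2, 24}
```

Element-wise combinators such as Axon.add/3 need both branches to produce compatible shapes, which is why the example above it projects both branches to 32 units before merging.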
If you visualize this model, you can see the full effect of the branching:
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
6["dense_0 (:dense) {2, 32}"];
9["dense_1 (:dense) {2, 64}"];
10["relu_0 (:relu) {2, 64}"];
13["dense_2 (:dense) {2, 32}"];
14["container_0 (:container) {{2, 32}, {2, 32}}"];
15["add_0 (:add) {2, 32}"];
14 --> 15;
13 --> 14;
6 --> 14;
10 --> 13;
9 --> 10;
3 --> 9;
3 --> 6;
You can also call Axon.build/2 on out as you would with any other Axon model:
{init_fn, predict_fn} = Axon.build(out)
{#Function<135.51955502/2 in Nx.Defn.Compiler.fun/2>,
#Function<135.51955502/2 in Nx.Defn.Compiler.fun/2>}
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
#Nx.Tensor<
f32[2][32]
[
[-3.4256787300109863, -0.866683840751648, -0.2629307508468628, 3.2555718421936035, 2.2740533351898193, 3.0403499603271484, -2.7904915809631348, 3.4799132347106934, -4.16396951675415, -4.545778274536133, 3.146249532699585, -3.0786540508270264, 3.4500746726989746, 1.1419837474822998, -0.7993628978729248, 2.3798861503601074, 4.787802696228027, 1.290929913520813, 1.8274409770965576, -1.5016275644302368, 3.441028118133545, -1.8077948093414307, 0.25549376010894775, -2.555987596511841, -4.643674850463867, 2.164360523223877, -0.30402517318725586, -2.54134464263916, -2.699089527130127, 4.074007511138916, -0.7711544036865234, -3.988246202468872],
[-11.235082626342773, -1.5991168022155762, -4.076810836791992, 11.091293334960938, 4.669280052185059, 12.756690979003906, -1.4954360723495483, 4.8143310546875, -14.211947441101074, -11.360504150390625, 6.239661693572998, -0.9994411468505859, 8.645132064819336, -0.5422897338867188, -1.4019453525543213, 9.633858680725098, 10.077424049377441, -0.3623824119567871, ...]
]
>
As your architectures grow in complexity, you might find yourself reaching for better abstractions to organize your model creation code. For example, PyTorch models are often organized into nn.Module. The equivalent of an nn.Module in Axon is a regular Elixir function. If you're translating models from PyTorch to Axon, it's natural to create one Elixir function per nn.Module.
You should write your models as you would write any other Elixir code - you don't need to worry about any framework-specific constructs:
defmodule MyModel do
  def model() do
    Axon.input("data")
    |> conv_block()
    |> Axon.flatten()
    |> dense_block()
    |> dense_block()
    |> Axon.dense(1)
  end

  defp conv_block(input) do
    residual = input
    x = input |> Axon.conv(3, padding: :same) |> Axon.mish()

    x
    |> Axon.add(residual)
    |> Axon.max_pool(kernel_size: {2, 2})
  end

  defp dense_block(input) do
    input |> Axon.dense(32) |> Axon.relu()
  end
end
{:module, MyModel, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:dense_block, 1}}
model = MyModel.model()
#Axon<
inputs: %{"data" => nil}
outputs: "dense_2"
nodes: 12
>
template = Nx.template({1, 28, 28, 3}, :f32)
Axon.Display.as_graph(model, template)
graph TD;
16[/"data (:input) {1, 28, 28, 3}"/];
19["conv_0 (:conv) {1, 28, 28, 3}"];
20["mish_0 (:mish) {1, 28, 28, 3}"];
21["container_0 (:container) {{1, 28, 28, 3}, {1, 28, 28, 3}}"];
22["add_0 (:add) {1, 28, 28, 3}"];
23["max_pool_0 (:max_pool) {1, 14, 14, 3}"];
24["flatten_0 (:flatten) {1, 588}"];
27["dense_0 (:dense) {1, 32}"];
28["relu_0 (:relu) {1, 32}"];
31["dense_1 (:dense) {1, 32}"];
32["relu_1 (:relu) {1, 32}"];
35["dense_2 (:dense) {1, 1}"];
32 --> 35;
31 --> 32;
28 --> 31;
27 --> 28;
24 --> 27;
23 --> 24;
22 --> 23;
21 --> 22;
16 --> 21;
20 --> 21;
19 --> 20;
16 --> 19;
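Because each block is a plain function, whatever a PyTorch nn.Module would take as constructor arguments can simply become function arguments. The sketch below is one possible way to do that, not a pattern prescribed by Axon. The BlockSketch module, its dense_block/3 helper, and the hidden_sizes argument are hypothetical names introduced for illustration, and the snippet assumes the Mix.install/2 setup at the top of this guide has already run:

```elixir
defmodule BlockSketch do
  # Hypothetical helper: units and the :activation option play the role
  # of nn.Module constructor arguments.
  def dense_block(input, units, opts \\ []) do
    activation = Keyword.get(opts, :activation, :relu)
    Axon.dense(input, units, activation: activation)
  end

  # Builds one dense_block per entry in hidden_sizes, then a single-unit output layer.
  def model(hidden_sizes) do
    hidden_sizes
    |> Enum.reduce(Axon.input("data"), fn units, acc -> dense_block(acc, units) end)
    |> Axon.dense(1)
  end
end

sketch_model = BlockSketch.model([64, 32])

{sketch_init_fn, sketch_predict_fn} = Axon.build(sketch_model)
sketch_params = sketch_init_fn.(Nx.template({4, 10}, :f32), %{})
sketch_output = sketch_predict_fn.(sketch_params, Nx.broadcast(1.0, {4, 10}))

Nx.shape(sketch_output)
# => {4, 1}
```

Since the model is assembled with ordinary Elixir, standard tools such as Enum.reduce/3 can generate repeated layers from a list of sizes.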