Complex models

Mix.install([
  {:axon, ">= 0.5.0"},
  {:kino, ">= 0.9.0"}
])
:ok

creating-more-complex-models

Creating more complex models

Not all models you'd want to create fit cleanly in the sequential paradigm. Some models require a more flexible API. Fortunately, because Axon models are just Elixir data structures, you can manipulate them and decompose architectures as you would any other Elixir program:

# A single named input that both branches below start from.
input = Axon.input("data")

# Branch 1: one dense layer with 32 units.
x1 = input |> Axon.dense(32)
# Branch 2: dense(64) -> relu -> dense(32). It ends at 32 units, matching x1,
# so the two branches can be added together.
x2 = input |> Axon.dense(64) |> Axon.relu() |> Axon.dense(32)

# Merge the two branches into a single model with an add layer.
out = Axon.add(x1, x2)
#Axon<
  inputs: %{"data" => nil}
  outputs: "add_0"
  nodes: 7
>

In the snippet above, your model branches input into x1 and x2. Each branch performs a different set of transformations; however, at the end the branches are merged with Axon.add/3. You might sometimes see layers like Axon.add/3 called combinators. Really they're just layers that operate on multiple Axon models at once - typically to merge some branches together.

out represents your final Axon model.

If you visualize this model, you can see the full effect of the branching:

# Input template: describes a {2, 8} f32 tensor (shape and type) for the "data" input.
template = Nx.template({2, 8}, :f32)
# Render the model as a graph; the per-layer shapes are derived from the template.
Axon.Display.as_graph(out, template)
graph TD;
3[/"data (:input) {2, 8}"/];
4["dense_0 (:dense) {2, 32}"];
5["dense_1 (:dense) {2, 64}"];
6["relu_0 (:relu) {2, 64}"];
7["dense_2 (:dense) {2, 32}"];
8["container_0 (:container) {{2, 32}, {2, 32}}"];
9["add_0 (:add) {2, 32}"];
8 --> 9;
7 --> 8;
4 --> 8;
6 --> 7;
5 --> 6;
3 --> 5;
3 --> 4;

And you can use Axon.build/2 on out as you would any other Axon model:

# Build the model into a pair of functions: one to initialize parameters,
# one to run predictions.
{init_fn, predict_fn} = Axon.build(out)
{#Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>,
 #Function<135.109794929/2 in Nx.Defn.Compiler.fun/2>}
# Initialize the model parameters from the input template. The second argument
# is the initial parameter map; it is empty here, so nothing is pre-seeded.
params = init_fn.(template, %{})
# Run the model on a concrete {2, 8} f32 input that matches the template.
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
#Nx.Tensor<
  f32[2][32]
  [
    [-4.283246040344238, 1.8983498811721802, 3.697357654571533, -4.720174789428711, 4.1636152267456055, 1.001131534576416, -0.7027540802955627, -3.7821826934814453, 0.027841567993164062, 9.267499923706055, 3.33616304397583, -1.5465859174728394, 8.983413696289062, 3.7445120811462402, 2.2405576705932617, -3.61336350440979, -1.7320983409881592, 0.5740477442741394, -0.22006472945213318, -0.1806044578552246, 1.1092393398284912, -0.29313594102859497, -0.41948509216308594, 3.526411533355713, -0.9127179384231567, 1.8373844623565674, 1.1746022701263428, -0.6885149478912354, -1.4326229095458984, -1.3498257398605347, -5.803186416625977, 1.5204020738601685],
    [-15.615742683410645, 6.555544853210449, 7.033155918121338, -12.33556842803955, 14.105436325073242, -4.230871200561523, 5.985136032104492, -8.445676803588867, 5.383096694946289, 23.413570404052734, 0.8907639980316162, -1.400709629058838, 19.19326400756836, 13.784171104431152, 9.641424179077148, -8.407038688659668, -5.688483238220215, 4.383636474609375, ...]
  ]
>

As your architectures grow in complexity, you might find yourself reaching for better abstractions to organize your model creation code. For example, PyTorch models are often organized into nn.Module subclasses. The equivalent of an nn.Module in Axon is a regular Elixir function. If you're translating models from PyTorch to Axon, it's natural to create one Elixir function per nn.Module.

You should write your models as you would write any other Elixir code - you don't need to worry about any framework specific constructs:

defmodule MyModel do
  @moduledoc """
  Example of organizing an Axon model as plain Elixir functions.

  Each reusable piece of the architecture is a private function that takes an
  Axon graph and returns an extended Axon graph.
  """

  @doc """
  Builds the full model.

  The `"data"` input goes through one residual convolution block, is
  flattened, passes through two dense blocks, and ends in a single-unit
  dense layer.
  """
  def model() do
    features =
      "data"
      |> Axon.input()
      |> conv_block()
      |> Axon.flatten()

    hidden =
      features
      |> dense_block()
      |> dense_block()

    Axon.dense(hidden, 1)
  end

  # Residual convolution block: conv -> mish, added back onto the block's own
  # input, then max-pooled with a 2x2 kernel.
  defp conv_block(input) do
    activated =
      input
      |> Axon.conv(3, padding: :same)
      |> Axon.mish()

    # Skip connection: the activated branch comes first, the untouched input second.
    merged = Axon.add(activated, input)

    Axon.max_pool(merged, kernel_size: {2, 2})
  end

  # Dense block: a 32-unit dense layer followed by a relu activation.
  defp dense_block(input) do
    projected = Axon.dense(input, 32)
    Axon.relu(projected)
  end
end
{:module, MyModel, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:dense_block, 1}}
# Build the Axon model by calling the ordinary Elixir function defined on MyModel.
model = MyModel.model()
#Axon<
  inputs: %{"data" => nil}
  outputs: "dense_2"
  nodes: 12
>
# Input template for the "data" input: a {1, 28, 28, 3} f32 tensor.
# NOTE(review): presumably batch x height x width x channels — confirm against
# the channel layout the conv and pooling layers are configured with.
template = Nx.template({1, 28, 28, 3}, :f32)
# Render the model as a graph; the per-layer shapes are derived from the template.
Axon.Display.as_graph(model, template)
graph TD;
10[/"data (:input) {1, 28, 28, 3}"/];
11["conv_0 (:conv) {1, 28, 28, 3}"];
12["mish_0 (:mish) {1, 28, 28, 3}"];
13["container_0 (:container) {{1, 28, 28, 3}, {1, 28, 28, 3}}"];
14["add_0 (:add) {1, 28, 28, 3}"];
15["max_pool_0 (:max_pool) {1, 14, 14, 3}"];
16["flatten_0 (:flatten) {1, 588}"];
17["dense_0 (:dense) {1, 32}"];
18["relu_0 (:relu) {1, 32}"];
19["dense_1 (:dense) {1, 32}"];
20["relu_1 (:relu) {1, 32}"];
21["dense_2 (:dense) {1, 1}"];
20 --> 21;
19 --> 20;
18 --> 19;
17 --> 18;
16 --> 17;
15 --> 16;
14 --> 15;
13 --> 14;
10 --> 13;
12 --> 13;
11 --> 12;
10 --> 11;