Multi-input / multi-output models

Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true},
  {:kino, "~> 0.7.0"}
])
:ok

Creating multi-input models

Sometimes your application requires multiple inputs. To use multiple inputs in an Axon model, you just need to declare multiple inputs in your graph:

input_1 = Axon.input("input_1")
input_2 = Axon.input("input_2")

out = Axon.add(input_1, input_2)
#Axon<
  inputs: %{"input_1" => nil, "input_2" => nil}
  outputs: "add_0"
  nodes: 4
>

Notice that when you inspect the model, it tells you what your model's inputs are up front. You can also get metadata about your model's inputs programmatically with Axon.get_inputs/1:

Axon.get_inputs(out)
%{"input_1" => nil, "input_2" => nil}

Each input is uniquely named, so you can pass inputs by name into inspection and execution functions with a map:

inputs = %{
  "input_1" => Nx.template({2, 8}, :f32),
  "input_2" => Nx.template({2, 8}, :f32)
}

Axon.Display.as_graph(out, inputs)
graph TD;
3[/"input_1 (:input) {2, 8}"/];
4[/"input_2 (:input) {2, 8}"/];
5["container_0 (:container) {{2, 8}, {2, 8}}"];
6["add_0 (:add) {2, 8}"];
5 --> 6;
4 --> 5;
3 --> 5;
{init_fn, predict_fn} = Axon.build(out)
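# an element-wise add of two inputs has no trainable parameters,
# so init_fn returns an empty parameter map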
params = init_fn.(inputs, %{})
%{}
inputs = %{
  "input_1" => Nx.iota({2, 8}, type: :f32),
  "input_2" => Nx.iota({2, 8}, type: :f32)
}

predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][8]
  [
    [0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0],
    [16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0]
  ]
>
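
Addition is only one way to combine inputs. As a sketch of an alternative (the concatenation-plus-dense model and its layer size here are illustrative assumptions), you can concatenate both inputs along the last axis and project the result with a dense layer:

concat_model =
  Axon.concatenate(input_1, input_2)
  |> Axon.dense(4)

templates = %{
  "input_1" => Nx.template({2, 8}, :f32),
  "input_2" => Nx.template({2, 8}, :f32)
}

{concat_init, concat_predict} = Axon.build(concat_model)
concat_params = concat_init.(templates, %{})

# the two {2, 8} inputs concatenate to {2, 16}, and the dense layer projects to {2, 4}
concat_predict.(concat_params, inputs)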

If you forget a required input, Axon will raise:

predict_fn.(params, %{"input_1" => Nx.iota({2, 8}, type: :f32)})
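
If you want to see the error without crashing your session, you can rescue it and print the message; a small sketch (the exact exception and message depend on your Axon version):

try do
  predict_fn.(params, %{"input_1" => Nx.iota({2, 8}, type: :f32)})
rescue
  error -> IO.puts(Exception.message(error))
end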

Creating multi-output models

Depending on your application, you might also want your model to have multiple outputs. You can achieve this by using Axon.container/2 to wrap multiple nodes into any supported Nx container:

inp = Axon.input("data")

x1 = inp |> Axon.dense(32) |> Axon.relu()
x2 = inp |> Axon.dense(64) |> Axon.relu()

out = Axon.container({x1, x2})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
template = Nx.template({2, 8}, :f32)
Axon.Display.as_graph(out, template)
graph TD;
7[/"data (:input) {2, 8}"/];
10["dense_0 (:dense) {2, 32}"];
11["relu_0 (:relu) {2, 32}"];
14["dense_1 (:dense) {2, 64}"];
15["relu_1 (:relu) {2, 64}"];
16["container_0 (:container) {{2, 32}, {2, 64}}"];
15 --> 16;
11 --> 16;
14 --> 15;
7 --> 14;
10 --> 11;
7 --> 10;

When executed, containers return a data structure matching the structure of the container you declared:

{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
{#Nx.Tensor<
   f32[2][32]
   [
     [0.0, 0.0, 3.111135482788086, 0.48920655250549316, 0.0, 0.5125713348388672, 0.0, 0.0, 1.482532262802124, 0.0, 0.0, 0.0, 0.0, 3.103637933731079, 0.46897295117378235, 2.6465413570404053, 2.837477445602417, 0.6159781217575073, 1.3220927715301514, 0.0, 0.24302834272384644, 3.4662821292877197, 0.40560781955718994, 0.0, 0.0, 0.2682836055755615, 3.5352964401245117, 0.0, 0.6591103672981262, 2.5643503665924072, 0.0, 0.0],
     [0.0, 0.0, 4.642599105834961, 0.0, 0.0, 1.8978865146636963, 2.2522430419921875, 0.0, 1.2110804319381714, 2.5524141788482666, 0.0, 0.742849588394165, 0.0, 8.30776596069336, 5.09386682510376, 4.69991397857666, 5.195588111877441, ...]
   ]
 >,
 #Nx.Tensor<
   f32[2][64]
   [
     [0.0, 0.0, 0.7948622107505798, 0.0, 0.0, 0.0, 0.0, 0.0, 2.3980231285095215, 5.2512712478637695, 1.5820361375808716, 0.0, 2.6624603271484375, 0.0, 0.0, 0.0, 1.6954007148742676, 0.017102837562561035, 0.7754535675048828, 0.0, 1.891753911972046, 0.0, 2.7824556827545166, 0.0, 0.5906356573104858, 0.0, 0.0, 1.288651466369629, 0.6939071416854858, 0.8427785038948059, 1.5664646625518799, 0.38097164034843445, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3193289637565613, 0.0, 0.0, 0.35316526889801025, 0.0, 1.2567038536071777, 0.7732977867126465, 0.16440902650356293, 0.0, 1.9872947931289673, ...],
     ...
   ]
 >}
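
Because the result is an ordinary tuple, you can pattern match on it to route each head to its own downstream logic. A small usage sketch (the variable names are just illustrative):

{head_32, head_64} = predict_fn.(params, Nx.iota({2, 8}, type: :f32))

{Nx.shape(head_32), Nx.shape(head_64)}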

You can output maps as well:

out = Axon.container(%{x1: x1, x2: x2})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
%{
  x1: #Nx.Tensor<
    f32[2][32]
    [
      [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.8718442916870117, 0.0, 1.813383936882019, 0.0, 0.0, 0.0, 0.0, 3.0636630058288574, 0.0, 1.1350113153457642, 1.7888737916946411, 0.0658932775259018, 0.0, 0.4498137831687927, 1.1311852931976318, 3.2784717082977295, 0.0, 2.4505443572998047, 3.346879005432129, 0.0, 0.0, 2.614570140838623, 0.0, 0.0, 0.8967163562774658, 0.0],
      [0.0, 0.0, 0.0, 1.9045438766479492, 0.0, 0.0, 7.110898971557617, 0.09859625995159149, 8.149545669555664, 0.0, 0.0, 0.0, 0.0, 4.178244113922119, 0.0, 3.8360297679901123, 6.177351474761963, ...]
    ]
  >,
  x2: #Nx.Tensor<
    f32[2][64]
    [
      [0.41670602560043335, 0.0, 0.0, 0.0, 1.338260531425476, 0.0, 0.5181264877319336, 1.1024510860443115, 0.0, 0.0, 1.485485553741455, 0.0, 0.0, 1.9365136623382568, 0.0, 0.0, 0.0, 0.0, 2.6925604343414307, 0.6202171444892883, 0.0, 0.08886899054050446, 0.0, 1.3045244216918945, 0.0, 0.0545249879360199, 0.0, 1.2294358015060425, 0.0, 0.0, 0.670710563659668, 0.0, 4.161868572235107, 1.880513072013855, 2.6189277172088623, 0.5702207684516907, 0.0, 1.953904151916504, 0.0, 0.0, 1.370330572128296, 0.17245425283908844, 1.9922431707382202, 2.6845364570617676, 0.3711611032485962, 0.7940037250518799, 0.0, 2.12975811958313, ...],
      ...
    ]
  >
}
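
Map outputs are convenient when a model has several heads and you want to refer to them by name rather than by position; for example, using the keys declared in the container above:

result = predict_fn.(params, Nx.iota({2, 8}, type: :f32))

{Nx.shape(result.x1), Nx.shape(result.x2)}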

Containers even support arbitrary nesting:

out = Axon.container({%{x1: {x1, x2}, x2: %{x1: x1, x2: {x2}}}})
#Axon<
  inputs: %{"data" => nil}
  outputs: "container_0"
  nodes: 6
>
{init_fn, predict_fn} = Axon.build(out)
params = init_fn.(template, %{})
predict_fn.(params, Nx.iota({2, 8}, type: :f32))
{%{
   x1: {#Nx.Tensor<
      f32[2][32]
      [
        [3.9104199409484863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.051666498184204, 0.0, 1.086042881011963, 0.6107193827629089, 0.5136545896530151, 2.7927842140197754, 0.0, 0.0, 0.0, 0.0, 2.472961902618408, 0.13712915778160095, 0.49807000160217285, 1.7868735790252686, 5.796293258666992, 0.0, 0.0, 4.727283477783203, 0.0, 0.0, 2.129516363143921, 0.0, 0.0],
        [11.746908187866211, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.840534687042236, 0.0, 0.0, 4.103122711181641, 1.0597835779190063, 8.971627235412598, ...]
      ]
    >,
    #Nx.Tensor<
      f32[2][64]
      [
        [0.951026439666748, 0.0, 0.6895619034767151, 0.12973949313163757, 3.0561492443084717, 0.0, 0.21812109649181366, 0.0, 0.6377829313278198, 0.0, 0.0, 0.0, 0.0, 1.6837494373321533, 0.0, 0.0, 0.0, 1.3907173871994019, 0.0, 0.0, 0.21352148056030273, 0.0, 1.2145031690597534, 0.0, 3.080430507659912, 0.0, 3.9572620391845703, 2.3347463607788086, 0.5280991196632385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.616438627243042, 0.0, 1.1335082054138184, 2.228783369064331, 0.0, 1.0927692651748657, 0.0, 0.0, 0.0, 0.0, 2.7719650268554688, ...],
        ...
      ]
    >},
   x2: %{
     x1: #Nx.Tensor<
       f32[2][32]
       [
         [3.9104199409484863, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.051666498184204, 0.0, 1.086042881011963, 0.6107193827629089, 0.5136545896530151, 2.7927842140197754, 0.0, 0.0, 0.0, 0.0, 2.472961902618408, 0.13712915778160095, 0.49807000160217285, 1.7868735790252686, 5.796293258666992, 0.0, 0.0, 4.727283477783203, 0.0, 0.0, 2.129516363143921, 0.0, 0.0],
         [11.746908187866211, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.840534687042236, 0.0, 0.0, 4.103122711181641, 1.0597835779190063, ...]
       ]
     >,
     x2: {#Nx.Tensor<
        f32[2][64]
        [
          [0.951026439666748, 0.0, 0.6895619034767151, 0.12973949313163757, 3.0561492443084717, 0.0, 0.21812109649181366, 0.0, 0.6377829313278198, 0.0, 0.0, 0.0, 0.0, 1.6837494373321533, 0.0, 0.0, 0.0, 1.3907173871994019, 0.0, 0.0, 0.21352148056030273, 0.0, 1.2145031690597534, 0.0, 3.080430507659912, 0.0, 3.9572620391845703, 2.3347463607788086, 0.5280991196632385, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.616438627243042, 0.0, 1.1335082054138184, 2.228783369064331, 0.0, 1.0927692651748657, 0.0, 0.0, 0.0, ...],
          ...
        ]
      >}
   }
 }}
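
Container outputs matter most at training time, where your targets need to mirror the structure of the model's output. Below is a minimal, hypothetical sketch for the earlier two-tuple model: a custom arity-2 loss destructures targets and predictions head by head and sums a per-head loss. The synthetic zero targets and the :adam optimizer are placeholder assumptions, not recommendations:

model = Axon.container({x1, x2})

# custom loss: pull apart targets and predictions per head,
# compute a per-head loss, and combine them into a single scalar
loss_fn = fn {y1_true, y2_true}, {y1_pred, y2_pred} ->
  Nx.add(
    Axon.Losses.mean_squared_error(y1_true, y1_pred, reduction: :mean),
    Axon.Losses.mean_squared_error(y2_true, y2_pred, reduction: :mean)
  )
end

# synthetic data stream: each batch is {input, {target_for_x1, target_for_x2}},
# so the target container mirrors the output container
data =
  Stream.repeatedly(fn ->
    x = Nx.iota({2, 8}, type: :f32)
    {x, {Nx.broadcast(0.0, {2, 32}), Nx.broadcast(0.0, {2, 64})}}
  end)

model
|> Axon.Loop.trainer(loss_fn, :adam)
|> Axon.Loop.run(data, %{}, iterations: 5)

The important detail is that the target container mirrors the output container, so the loss can destructure both the same way.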