Custom models, loss functions, and optimizers

Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
:ok

Using custom models in training loops

In the Your first training loop guide, you learned how to declare a supervised training loop using Axon.Loop.trainer/3 with a model, loss function, and optimizer. Your overall model and loop declaration looked something like this:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

This example uses an %Axon{} struct to represent your model to train, and atoms to represent your loss function and optimizer. Some of your problems will require a bit more flexibility than this example affords. Fortunately, Axon.Loop.trainer/3 is designed for flexibility.

For example, if your model cannot be cleanly represented as an %Axon{} model, you can instead opt to define custom initialization and forward functions to pass to Axon.Loop.trainer/3. In fact, Axon.Loop.trainer/3 is doing this for you under the hood - the ability to pass an %Axon{} struct directly is just a convenience:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

lowered_model = {init_fn, predict_fn} = Axon.build(model)

loop = Axon.Loop.trainer(lowered_model, :mean_squared_error, :sgd)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>}
  },
  ...
>

Notice that Axon.Loop.trainer/3 handles the "lowered" form of an Axon model without issue. When you pass an %Axon{} struct, the trainer factory converts it to a lowered representation for you. With this construct, you can build custom models entirely with Nx defn, or readily mix your Axon models into custom workflows without worrying about compatibility with the Axon.Loop API:

defmodule CustomModel do
  import Nx.Defn

  defn custom_predict_fn(model_predict_fn, params, input) do
    %{prediction: preds} = out = model_predict_fn.(params, input)
    %{out | prediction: Nx.cos(preds)}
  end
end
{:module, CustomModel, <<70, 79, 82, 49, 0, 0, 9, ...>>, {:custom_predict_fn, 3}}
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
custom_predict_fn = &CustomModel.custom_predict_fn(predict_fn, &1, &2)

loop = Axon.Loop.trainer({init_fn, custom_predict_fn}, :mean_squared_error, :sgd)

Axon.Loop.run(loop, train_data, %{}, iterations: 500)
Epoch: 0, Batch: 500, loss: 0.3053460
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.06573846191167831, 0.37533989548683167, -0.014221129938960075, -0.0056641618721187115, -0.013241665437817574, -0.04930500313639641, 0.03238297998905182, 0.019304191693663597]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.3132522702217102, -0.9284062385559082, 0.5041953921318054, 0.09051526337862015, 0.003381401300430298, -0.22686156630516052, 0.506594181060791, 0.46744370460510254]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.008441010490059853, 0.0, 0.5370790958404541, 0.03584281727671623]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.3442431688308716, -0.33131587505340576, -0.03751888871192932, -0.5497395396232605],
        [-0.4568001925945282, -0.5024663805961609, 0.8712142109870911, -0.13484779000282288],
        [0.7310590744018555, -0.34318023920059204, 0.3977772295475006, -0.6045383214950562],
        [-0.5255699157714844, -0.2829623818397522, -0.45367464423179626, -0.157784566283226],
        [-0.47948920726776123, 0.2930692136287689, -0.3784458339214325, -0.69244384765625],
        [0.7052943706512451, 0.015830136835575104, -0.02979498915374279, 0.6160839796066284],
        [0.3201732933521271, -0.1367085874080658, -0.17100055515766144, 0.7335636019706726],
        [-0.2825513482093811, -0.424674928188324, -0.3110836148262024, 0.46001508831977844]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.6889857649803162]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.7191283106803894],
        [-0.4222411513328552],
        [1.122635006904602],
        [-0.7385509014129639]
      ]
    >
  }
}

Using custom loss functions in training loops

Just as Axon.Loop.trainer/3 allows more flexibility with models, it also supports more flexible loss functions. In most cases, you can get away with using one of Axon's built-in loss functions by specifying an atom. Atoms map directly to a loss function defined in Axon.Losses. Under the hood, Axon.Loop.trainer/3 is doing something like:

loss_fn = &apply(Axon.Losses, loss_atom, [&1, &2])
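
Because these map to ordinary functions, you can call a built-in loss directly on tensors to see what it computes. A quick sketch (the tensors here are illustrative):

y_true = Nx.tensor([[1.0], [0.0]])
y_pred = Nx.tensor([[0.8], [0.1]])

# Pass reduction: :mean to reduce the per-example losses to a scalar
Axon.Losses.mean_squared_error(y_true, y_pred, reduction: :mean)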

Rather than pass an atom, you can pass your own custom arity-2 function to Axon.Loop.trainer/3. This arises most often in cases where you want to control some parameters of the loss function, such as the batch-level reduction:

loss_fn = &Axon.Losses.mean_squared_error(&1, &2, reduction: :sum)

loop = Axon.Loop.trainer(model, loss_fn, :sgd)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     #Function<41.3316493/2 in :erl_eval.expr/6>}
  },
  ...
>

You can also define your own custom loss functions, so long as they match the following spec:

loss(
  y_true :: tensor[batch, ...] | container(tensor),
  y_preds :: tensor[batch, ...] | container(tensor)
) :: scalar
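
As a minimal sketch of a loss satisfying this spec, here is a hand-written mean absolute error that reduces to a scalar with Nx.mean/1 (the custom_mae name is just illustrative):

custom_mae = fn y_true, y_pred ->
  # Element-wise absolute error, reduced to a scalar
  y_true
  |> Nx.subtract(y_pred)
  |> Nx.abs()
  |> Nx.mean()
end

loop = Axon.Loop.trainer(model, custom_mae, :sgd)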

Custom loss functions are especially useful when dealing with multi-output scenarios. For example, it's very easy to construct a custom loss function which is a weighted average of several loss functions on multiple outputs:

train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    y1 = Nx.sin(xs)
    y2 = Nx.cos(xs)
    {xs, {y1, y2}}
  end)

shared =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

y1 = Axon.dense(shared, 1)
y2 = Axon.dense(shared, 1)

model = Axon.container({y1, y2})

custom_loss_fn = fn {y_true1, y_true2}, {y_pred1, y_pred2} ->
  loss1 = Axon.Losses.mean_squared_error(y_true1, y_pred1, reduction: :mean)
  loss2 = Axon.Losses.mean_squared_error(y_true2, y_pred2, reduction: :mean)

  loss1
  |> Nx.multiply(0.4)
  |> Nx.add(Nx.multiply(loss2, 0.6))
end

model
|> Axon.Loop.trainer(custom_loss_fn, :sgd)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.1098235
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.07738334685564041, 0.04548311233520508, 0.049238916486501694, 0.38714033365249634, -0.030310271307826042, -0.07575170695781708, 0.02918776497244835, 0.15639683604240417]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.5250527858734131, 0.9252119660377502, -0.7720071077346802, 0.3685735762119293, -0.15688209235668182, -0.41163918375968933, 0.7827479839324951, 0.07295594364404678]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.012770675122737885, 0.6008449792861938, 0.29370757937431335, -0.05354489013552666]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.08783119916915894, 0.4296257495880127, 0.07153885811567307, -0.6921477317810059],
        [0.15848888456821442, -0.4663836658000946, 0.7126847505569458, 0.0693722814321518],
        [-0.24852830171585083, -0.7588720321655273, -0.5033655166625977, 0.6524038314819336],
        [0.2933746874332428, 0.6656989455223083, -0.046741705387830734, 0.44998466968536377],
        [0.17215801775455475, -0.3072860836982727, 0.2046997845172882, -0.7001357078552246],
        [0.6354788541793823, -0.12706635892391205, -0.18666459619998932, -0.26693975925445557],
        [-0.3737913966178894, -0.07344938814640045, 0.22658668458461761, -0.37110695242881775],
        [0.01989569514989853, 0.39410898089408875, -0.30496707558631897, -0.4945743680000305]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.5888826251029968]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.0239059925079346],
        [0.25252565741539],
        [0.8877795338630676],
        [-0.13882321119308472]
      ]
    >
  },
  "dense_3" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.2557465434074402]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.6269392371177673],
        [1.1281259059906006],
        [-0.503214418888092],
        [-0.5435869693756104]
      ]
    >
  }
}

Using custom optimizers in training loops

As you might expect, it's also possible to customize the optimizer passed to Axon.Loop.trainer/3. If you read the Polaris.Updates documentation, you'll learn that optimizers are actually represented as the tuple {init_fn, update_fn}, where init_fn initializes optimizer state from the model state and update_fn scales gradients using the optimizer state, the gradients, and the model state.
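
To make that contract concrete, here is a sketch, assuming Polaris.Updates keeps the Axon.Updates calling convention (init_fn.(params) returns optimizer state; update_fn.(gradients, state, params) returns the scaled updates and new state). Plain SGD is just Polaris.Updates.scale/1 with a negative learning rate:

# Scale every gradient by -1.0e-3 (vanilla SGD without momentum)
{init_fn, update_fn} = Polaris.Updates.scale(-1.0e-3)

params = %{"w" => Nx.tensor([1.0, 2.0])}
grads = %{"w" => Nx.tensor([0.5, 0.5])}

optimizer_state = init_fn.(params)
{updates, _new_state} = update_fn.(grads, optimizer_state, params)

# apply_updates adds the scaled updates to the parameters
Polaris.Updates.apply_updates(params, updates)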

You likely won't have to implement a custom optimizer; however, you should know how to construct optimizers with different hyperparameters and how to apply different modifiers to different optimizers to customize the optimization process. A sketch of composing such modifiers appears at the end of this section.

When you specify an optimizer as an atom in Axon.Loop.trainer/3, it maps directly to an optimizer declared in Polaris.Optimizers. You can instead opt to declare your optimizer directly. This is most useful for controlling things like the learning rate and various optimizer hyperparameters:

train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

optimizer = {_init_optimizer_fn, _update_fn} = Polaris.Optimizers.sgd(learning_rate: 1.0e-3)

model
|> Axon.Loop.trainer(:mean_squared_error, optimizer)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0992607
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.06136200204491615, -0.08278193324804306, -0.07280997931957245, 0.08740464597940445, 0.08663233369588852, -0.06915996968746185, 0.03753892332315445, 0.06512840837240219]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.622833251953125, 0.24778570234775543, 0.4959430694580078, -0.604946494102478, -0.31578049063682556, 0.09977878630161285, 0.776294469833374, 0.5804685950279236]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.012786266393959522, 0.01057625561952591, 0.10597240924835205, 0.13692162930965424]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.46233609318733215, -0.7435348033905029, -0.10738609731197357, 0.09911829978227615],
        [0.5295257568359375, 0.48769527673721313, -0.23950818181037903, -0.26084062457084656],
        [-0.5117107033729553, 0.2039143443107605, -0.12630638480186462, -0.41089773178100586],
        [-0.6043668985366821, 0.3961969316005707, 0.5120400190353394, -0.6773409247398376],
        [0.22123000025749207, 0.7197521924972534, 0.2679356038570404, -0.12402179092168808],
        [0.4830038249492645, 0.3629038631916046, 0.49994897842407227, -0.25865232944488525],
        [0.29824453592300415, 0.29333528876304626, -0.05371938645839691, 0.5230391621589661],
        [0.5483304262161255, 0.08283360302448273, -0.6959219574928284, 0.6471460461616516]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.07759959995746613]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.036170706152915955],
        [-0.5362256765365601],
        [-0.6853286027908325],
        [0.6693617701530457]
      ]
    >
  }
}
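
To customize the optimization process beyond hyperparameters, you can also compose gradient transformations from Polaris.Updates before handing the result to Axon.Loop.trainer/3. The following is a hedged sketch, assuming Polaris.Updates carries over the pipeable Axon.Updates API (clip_by_global_norm, scale_by_adam, and scale): clip gradients by global norm, apply Adam-style scaling, then scale by a negative learning rate:

# Assumes the Axon.Updates-style composition API is available in Polaris.Updates
clipped_adam =
  Polaris.Updates.clip_by_global_norm(max_norm: 1.0)
  |> Polaris.Updates.scale_by_adam()
  |> Polaris.Updates.scale(-1.0e-3)

model
|> Axon.Loop.trainer(:mean_squared_error, clipped_adam)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)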