Writing custom event handlers

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Writing custom event handlers

If you require functionality not offered by any of Axon's built-in event handlers, then you'll need to write a custom event handler. Custom event handlers are functions which accept loop state, perform some action, and then defer execution back to the main loop. For example, you can write custom loop handlers which visualize model outputs, communicate with an external Kino process, or simply halt the loop based on some criteria.

All event handlers must accept an %Axon.Loop.State{} struct and return a tuple of {control_term, state} where control_term is one of :continue, :halt_epoch, or :halt_loop and state is the updated loop state:

defmodule CustomEventHandler0 do
  @moduledoc """
  Minimal example of a custom Axon loop event handler.
  """

  @doc """
  Prints a message and lets the loop keep running.

  Takes the current `%Axon.Loop.State{}` and returns `{:continue, state}`
  with the state passed through unchanged.
  """
  def my_weird_handler(%Axon.Loop.State{} = loop_state) do
    IO.puts("My weird handler: fired")
    {:continue, loop_state}
  end
end
{:module, CustomEventHandler0, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:my_weird_handler, 1}}

To register event handlers, you use Axon.Loop.handle_event/4:

# A small feed-forward network: one input named "data", two hidden dense
# layers (8 and 4 units) each followed by ReLU, and a single-unit output layer.
input = Axon.input("data")

hidden =
  input
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

model = Axon.dense(hidden, 1)

# Build a training loop for `model` (mean squared error loss, SGD optimizer),
# then attach the custom handler so it runs on every :epoch_completed event.
trainer = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

loop =
  Axon.Loop.handle_event(
    trainer,
    :epoch_completed,
    &CustomEventHandler0.my_weird_handler/1
  )
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {&CustomEventHandler0.my_weird_handler/1,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>},
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Axon will trigger your custom handler to run on the attached event:

# Infinite, lazily generated stream of {inputs, targets} batches. Each batch
# draws 8 normally distributed samples (shape {8, 1}) and uses sin(x) as the
# target.
train_data =
  Stream.repeatedly(fn ->
    # Seed a fresh Nx PRNG key per batch. `:rand` is used instead of Erlang's
    # deprecated `:random` module; `:rand.uniform(9999)` yields an integer in
    # 1..9999, the same range as before.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    ys = Nx.sin(xs)
    {xs, ys}
  end)

# Run the loop for 5 epochs of 100 iterations each, starting from an empty
# initial model state (%{}). The stream is infinite, so `iterations:` is what
# bounds each epoch.
Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.0990703
My weird handler: fired
Epoch: 1, Batch: 50, loss: 0.0567622
My weird handler: fired
Epoch: 2, Batch: 50, loss: 0.0492784
My weird handler: fired
Epoch: 3, Batch: 50, loss: 0.0462587
My weird handler: fired
Epoch: 4, Batch: 50, loss: 0.0452806
My weird handler: fired
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.10819189250469208, 0.008151392452418804, -0.0318693183362484, 0.010302421636879444, 0.15788722038269043, 0.05119801685214043, 0.14268818497657776, -0.11528034508228302]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.4275593161582947, 0.40442031621932983, 0.7287659645080566, -0.7832129597663879, 0.3329123258590698, -0.5598123073577881, 0.8389336466789246, 0.3197469413280487]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0671013742685318, 0.13561469316482544, 0.06218714639544487, 0.2104845941066742]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.4444102942943573, 0.4518184959888458, 0.45315614342689514, 0.35392478108406067],
        [0.008407601155340672, -0.6081852912902832, -0.05863206833600998, 0.14386630058288574],
        [-0.010219200514256954, -0.5528244376182556, 0.3754919469356537, -0.6242967247962952],
        [0.3531058132648468, -0.18348301947116852, -0.0019897441379725933, 0.41002658009529114],
        [0.676723062992096, -0.09349705278873444, 0.1101854145526886, 0.06494166702032089],
        [0.1534113883972168, 0.6402403116226196, 0.23490086197853088, -0.2196572870016098],
        [0.5835862755775452, -0.6581316590309143, -0.3047991394996643, -0.07485166192054749],
        [-0.6115342378616333, 0.3316897749900818, -0.3606548309326172, 0.3397740423679352]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.10111129283905029]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.7433153390884399],
        [-0.8213723301887512],
        [-0.44361063838005066],
        [-1.049617052078247]
      ]
    >
  }
}

You can use event handlers to early-stop a loop or loop epoch by returning a :halt_* control term. Halt control terms can be one of :halt_epoch or :halt_loop. :halt_epoch halts the current epoch and continues to the next. :halt_loop halts the loop altogether.

defmodule CustomEventHandler1 do
  @moduledoc """
  Example event handler that stops the whole loop as soon as it fires.
  """

  @doc """
  Prints a message and halts the loop.

  Takes the current `%Axon.Loop.State{}` and returns `{:halt_loop, state}`
  with the state passed through unchanged.
  """
  def always_halts(%Axon.Loop.State{} = loop_state) do
    IO.puts("stopping loop")
    {:halt_loop, loop_state}
  end
end
{:module, CustomEventHandler1, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:always_halts, 1}}

The loop will immediately stop executing and return the current state at the time it was halted:

# Same trainer as before, but the :epoch_completed handler now returns
# :halt_loop, so the run stops after the first completed epoch.
halting_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.handle_event(:epoch_completed, &CustomEventHandler1.always_halts/1)

Axon.Loop.run(halting_loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.2201974
stopping loop
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.07676638662815094, -0.18689222633838654, 0.10066182911396027, -0.021994125097990036, 0.12006694823503494, -0.014219668693840504, 0.13600556552410126, -0.017512166872620583]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.5354958772659302, -0.216745987534523, -0.5694359540939331, 0.023495405912399292, 0.17701618373394012, 0.011712944135069847, 0.5289720892906189, 0.07360327988862991]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0012482400052249432, 0.09300543367862701, 0.08570009469985962, -0.018982920795679092]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.3016211688518524, 0.31998082995414734, -0.3300730884075165, 0.24982869625091553],
        [0.03864569962024689, -0.44071364402770996, 0.6553062200546265, -0.5294798612594604],
        [0.25020459294319153, 0.7249991297721863, 0.15611837804317474, -0.5045580863952637],
        [-0.5500670075416565, 0.15677094459533691, -0.6531851291656494, -0.09289993345737457],
        [0.1618722379207611, 0.4479053020477295, 0.705923318862915, -0.3853490352630615],
        [-0.6752215623855591, 0.577272891998291, -0.1268012821674347, 0.6133111715316772],
        [0.5361366271972656, -0.2996085286140442, 0.28480708599090576, 0.47739118337631226],
        [-0.6443014144897461, -0.2866927981376648, 0.023463081568479538, -0.1491370052099228]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.0047520860098302364]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.3796459138393402],
        [-0.9757304191589355],
        [0.9530885815620422],
        [-0.05134368687868118]
      ]
    >
  }
}

Note that halting an epoch fires a different event (:epoch_halted) than completing an epoch (:epoch_completed). So if you implement a custom handler to halt the loop when an epoch completes, it will never fire if the epoch always halts prematurely:

defmodule CustomEventHandler2 do
  @moduledoc """
  Example event handlers that halt either the current epoch or the whole loop.
  """

  @doc """
  Prints a message and halts the current epoch.

  Returns `{:halt_epoch, state}` with the state passed through unchanged.
  """
  def always_halts_epoch(%Axon.Loop.State{} = loop_state) do
    IO.puts("\nstopping epoch")
    {:halt_epoch, loop_state}
  end

  @doc """
  Prints a message and halts the whole loop.

  Returns `{:halt_loop, state}` with the state passed through unchanged.
  """
  def always_halts_loop(%Axon.Loop.State{} = loop_state) do
    IO.puts("stopping loop\n")
    {:halt_loop, loop_state}
  end
end
{:module, CustomEventHandler2, <<70, 79, 82, 49, 0, 0, 8, ...>>, {:always_halts_loop, 1}}

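If you run these handlers in conjunction, every epoch is halted after its first iteration, so :epoch_completed never fires and the :halt_loop handler never runs. The loop therefore goes through all five epochs instead of terminating early: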

# Attach both handlers: one halts the epoch after every iteration, the other
# would halt the loop on :epoch_completed.
combined_loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.handle_event(:iteration_completed, &CustomEventHandler2.always_halts_epoch/1)
  |> Axon.Loop.handle_event(:epoch_completed, &CustomEventHandler2.always_halts_loop/1)

Axon.Loop.run(combined_loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 0, loss: 0.0000000
stopping epoch

stopping epoch

stopping epoch

stopping epoch

stopping epoch
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.009215549565851688, -0.005282022058963776, -0.0023747326340526342, 0.002623362001031637, 0.003890525083988905, 6.010813522152603e-4, -0.0024882694706320763, 0.0029246946796774864]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.3484582304954529, -0.39938971400260925, 0.03963512182235718, -0.3549930155277252, 0.09539157152175903, 0.5987873077392578, -0.23635399341583252, 0.01850329153239727]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.00194685033056885, 0.007812315598130226, 0.01710106059908867, 0.0080711729824543]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.6497661471366882, -0.3379145562648773, 0.3343344032764435, 0.4334254860877991],
        [-0.37884217500686646, -0.41724908351898193, -0.19513007998466492, -0.22494879364967346],
        [-0.42438197135925293, -0.40400123596191406, 0.5355109572410583, 0.4295356869697571],
        [0.15086597204208374, 0.30529624223709106, 0.002222923096269369, 0.32834741473197937],
        [-0.09336567670106888, 0.471781849861145, -0.06567475199699402, -0.4361487627029419],
        [0.23664812743663788, 0.13572633266448975, -0.13837064802646637, -0.09471122920513153],
        [0.6461064219474792, -0.2435072958469391, -0.04861235246062279, -0.1969985067844391],
        [0.17856749892234802, 0.41614532470703125, -0.06008348613977432, -0.3271574079990387]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.005317525006830692]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.07891849428415298],
        [0.32653072476387024],
        [-0.5885495543479919],
        [-0.2781771719455719]
      ]
    >
  }
}

You may access and update any portion of the loop state. Keep in mind that event handlers are not JIT-compiled, so you should be certain to manually JIT-compile any long-running or expensive operations.