View Source Writing custom event handlers

# Install the dependencies this notebook needs: Axon from the elixir-nx/axon
# GitHub repository, and Nx from the elixir-nx/nx GitHub repository.
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  # `sparse: "nx"` checks out only the `nx` directory of the elixir-nx/nx
  # repository. `override: true` makes this Nx entry win over the Nx
  # requirement declared by other dependencies (presumably Axon's own).
  # NOTE(review): the "~> 0.3.0" requirement is checked against whatever
  # version the GitHub checkout currently reports; confirm it still resolves.
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
:ok

writing-custom-event-handlers

Writing custom event handlers

If you require functionality not offered by any of Axon's built-in event handlers, then you'll need to write a custom event handler. Custom event handlers are functions which accept loop state, perform some action, and then defer execution back to the main loop. For example, you can write custom event handlers which visualize model outputs, communicate with an external Kino process, or simply halt the loop based on some criteria.

All event handlers must accept an %Axon.Loop.State{} struct and return a tuple of {control_term, state} where control_term is one of :continue, :halt_epoch, or :halt_loop and state is the updated loop state:

defmodule CustomEventHandler do
  @moduledoc """
  Example custom event handlers for an `Axon.Loop`.
  """

  alias Axon.Loop.State

  @doc """
  Prints a message and lets the loop keep running.

  Accepts the current `%Axon.Loop.State{}` and returns `{:continue, state}`
  with the state passed through unchanged.
  """
  def my_weird_handler(%State{} = state) do
    IO.puts("My weird handler: fired")

    # Every handler returns `{control_term, state}`; `:continue` hands control
    # back to the loop without halting anything.
    control_term = :continue
    {control_term, state}
  end
end
{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:my_weird_handler, 1}}

To register event handlers, you use Axon.Loop.handle/4:

# A small feed-forward model: an input named "data" followed by three dense
# layers (8, 4, and 1 units), with a ReLU after each of the first two.
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# Build a training loop (mean squared error loss, SGD optimizer) and register
# the custom handler on the :epoch_completed event. The loop is only
# constructed here; it does not run until it is passed to Axon.Loop.run.
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.my_weird_handler/1)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {&CustomEventHandler.my_weird_handler/1,
       #Function<5.33119226/1 in Axon.Loop.build_filter_fn/1>},
      {#Function<23.33119226/1 in Axon.Loop.log/5>,
       #Function<5.33119226/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.33119226/1 in Axon.Loop.log/5>,
       #Function<3.33119226/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.46375131/3 in Axon.Metrics.running_average/1>,
     #Function<6.33119226/2 in Axon.Loop.build_loss_fn/1>}
  },
  ...
>

Axon will trigger your custom handler to run on the attached event:

# Endless stream of training batches: each batch pairs random-normal inputs of
# shape {8, 1} with their element-wise sine as the target.
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

# Run the loop for 5 epochs of 100 iterations each. Because the stream never
# ends, `iterations: 100` is what bounds each epoch. `%{}` is passed as the
# third argument (presumably the initial loop state; confirm against the
# Axon.Loop.run docs).
Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 100, loss: 0.1905403
My weird handler: fired
Epoch: 1, Batch: 100, loss: 0.1478554
My weird handler: fired
Epoch: 2, Batch: 100, loss: 0.1184390
My weird handler: fired
Epoch: 3, Batch: 100, loss: 0.0983292
My weird handler: fired
Epoch: 4, Batch: 100, loss: 0.0845697
My weird handler: fired
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.014659373089671135, 0.08941870182752609, -0.09661660343408585, 0.2650177478790283, -0.06400775164365768, -0.07953602075576782, 0.22094617784023285, -0.014790073968470097]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.3581556975841522, 0.38828182220458984, -0.3311854302883148, -0.4059808552265167, 0.6334917545318604, 0.17008493840694427, -0.5630434155464172, 0.3790667653083801]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.3047839403152466, -0.025677276775240898, 0.18113580346107483, 0.19019420444965363]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.25477269291877747, 0.28833284974098206, -0.25498083233833313, 0.40912926197052],
        [-0.387851357460022, 0.009837300516664982, -0.48930269479751587, -0.6119663715362549],
        [0.49769237637519836, -0.45746952295303345, -0.3886529505252838, -0.49895355105400085],
        [0.6451961994171143, 0.16054697334766388, 0.27802371978759766, -0.15226426720619202],
        [0.17125651240348816, -0.048851024359464645, 0.19429178535938263, 0.24933232367038727],
        [0.5465306043624878, -0.15836869180202484, 0.39782997965812683, -0.3635501563549042],
        [-0.36660289764404297, -0.011948992498219013, 0.48680511116981506, 0.5263928174972534],
        [-0.6284276843070984, -0.5880372524261475, 0.004470183979719877, -0.4550755023956299]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.7117368578910828]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.7743457555770874],
        [0.3977936804294586],
        [-1.0638943910598755],
        [-0.6494196653366089]
      ]
    >
  }
}

You can use event handlers to early-stop a loop or loop epoch by returning a :halt_* control term. Halt control terms can be one of :halt_epoch or :halt_loop. :halt_epoch halts the current epoch and continues to the next. :halt_loop halts the loop altogether.

defmodule CustomEventHandler do
  @moduledoc """
  Example custom event handler that stops an `Axon.Loop` early.
  """

  alias Axon.Loop.State

  @doc """
  Prints a message and halts the whole loop.

  Accepts the current `%Axon.Loop.State{}` and returns `{:halt_loop, state}`
  with the state passed through unchanged.
  """
  def always_halts(%State{} = state) do
    IO.puts("stopping loop")

    # `:halt_loop` stops the loop altogether, not just the current epoch.
    control_term = :halt_loop
    {control_term, state}
  end
end
{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 6, ...>>, {:always_halts, 1}}

The loop will immediately stop executing and return the current state at the time it was halted:

# Train with a handler that returns :halt_loop on :epoch_completed. Although
# `epochs: 5` is requested, the run stops once the first epoch completes.
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.always_halts/1)
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 100, loss: 0.1967763
stopping loop
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.05958094820380211, 0.08930676430463791, -0.006259916350245476, 0.05067025125026703, 0.10981185734272003, -0.011248357594013214, -0.007601946126669645, 0.036958880722522736]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.050393108278512955, -0.5486620664596558, 0.6901980042457581, 0.42280837893486023, 0.6446300745010376, 0.25207778811454773, -0.13566234707832336, 0.26625606417655945]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.06729397922754288, 0.14259757101535797, -0.0020351663697510958, 0.16679106652736664]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.5964004397392273, -0.5631846785545349, 0.15613533556461334, 0.1943722516298294],
        [0.19513694941997528, -0.24765732884407043, -0.06751974672079086, 0.6707308292388916],
        [-0.6826592087745667, -0.006577506195753813, -0.6097249984741211, -0.5801466703414917],
        [-0.30076032876968384, 0.34819719195365906, -0.5906499028205872, -0.37741175293922424],
        [0.16266342997550964, 0.7666646838188171, 0.6456886529922485, -0.4589986801147461],
        [-0.2686948776245117, -0.06113003194332123, 0.22663049399852753, -0.12092678993940353],
        [-0.5785921216011047, -0.641874372959137, -0.24317769706249237, -0.2897084951400757],
        [0.14917287230491638, 0.24462535977363586, -0.64858478307724, -0.5138146877288818]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.11649220436811447]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.7849427461624146],
        [0.5966104865074158],
        [-0.5520159602165222],
        [-0.4974740147590637]
      ]
    >
  }
}

Note that halting an epoch fires a different event (:epoch_halted) than completing an epoch (:epoch_completed). So if you implement a custom handler to halt the loop when an epoch completes, that handler will never fire if every epoch is halted prematurely:

defmodule CustomEventHandler do
  @moduledoc """
  Example custom event handlers showing the difference between halting an
  epoch and halting the whole `Axon.Loop`.
  """

  alias Axon.Loop.State

  @doc """
  Prints a message and halts the current epoch.

  Accepts the current `%Axon.Loop.State{}` and returns `{:halt_epoch, state}`
  with the state passed through unchanged.
  """
  def always_halts_epoch(%State{} = state) do
    IO.puts("\nstopping epoch")

    # `:halt_epoch` ends only the current epoch; the loop moves on to the next.
    control_term = :halt_epoch
    {control_term, state}
  end

  @doc """
  Prints a message and halts the whole loop.

  Accepts the current `%Axon.Loop.State{}` and returns `{:halt_loop, state}`
  with the state passed through unchanged.
  """
  def always_halts_loop(%State{} = state) do
    IO.puts("stopping loop\n")

    # `:halt_loop` stops the loop altogether.
    control_term = :halt_loop
    {control_term, state}
  end
end
{:module, CustomEventHandler, <<70, 79, 82, 49, 0, 0, 7, ...>>, {:always_halts_loop, 1}}

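If you run these handlers in conjunction, the loop will not terminate prematurely. Each epoch is halted after its first iteration, so the :epoch_completed event never fires and always_halts_loop is never called: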

# Register both handlers: one halts the epoch on every :iteration_completed
# event, the other would halt the loop on :epoch_completed. Because each epoch
# is halted rather than completed, the second handler is never triggered and
# all 5 epochs are attempted.
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.handle(:iteration_completed, &CustomEventHandler.always_halts_epoch/1)
|> Axon.Loop.handle(:epoch_completed, &CustomEventHandler.always_halts_loop/1)
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 0, loss: 0.0000000
stopping epoch
Epoch: 0, Batch: 0, loss: 0.7256396
stopping epoch
Epoch: 0, Batch: 0, loss: 0.4574284
stopping epoch
Epoch: 0, Batch: 0, loss: 0.4981923
stopping epoch
Epoch: 0, Batch: 0, loss: 0.4377063
stopping epoch
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [9.248655405826867e-4, -0.0038722341414541006, -0.0015197680331766605, -0.001993122510612011, -0.0015419051051139832, -0.004070846363902092, 0.001461982261389494, 0.0043989671394228935]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.6537156701087952, 0.2857331335544586, -0.339731365442276, 0.46841081976890564, -0.5864744782447815, -0.364472359418869, -0.5385616421699524, -0.694677472114563]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, -0.017093738541007042, 0.00152371556032449, -0.0019599769730120897]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.21336764097213745, -0.6211493611335754, 0.676548957824707, 0.3768426477909088],
        [-0.24921125173568726, 0.217195525765419, 0.23704318702220917, 0.1597728431224823],
        [-0.12178827077150345, -0.4966273307800293, -0.283501535654068, 0.00888047181069851],
        [-0.19504092633724213, 0.18697738647460938, 0.14705461263656616, 0.39286476373672485],
        [-0.5945789813995361, -0.5958647727966309, -0.3320448100566864, -0.02747068926692009],
        [-0.2157520055770874, -0.2990635335445404, -0.16008871793746948, 0.4921063184738159],
        [-0.529068648815155, -0.383655846118927, -0.07292155921459198, -0.2834954559803009],
        [-0.3056498169898987, -0.28507867455482483, 0.554026186466217, -0.24665579199790955]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.010511377826333046]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.9865502119064331],
        [-0.686279296875],
        [-0.15436960756778717],
        [0.18355509638786316]
      ]
    >
  }
}

You may access and update any portion of the loop state. Keep in mind that event handlers are not JIT-compiled, so you should be certain to manually JIT-compile any long-running or expensive operations.