Using loop event handlers

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Adding event handlers to training loops

Often you want finer-grained control over what happens during loop execution. For example, you might want to save loop state to a file every 500 iterations, or log some output to standard output at the end of every epoch. Axon loops provide this control through events and event handlers.

Axon fires a number of events during loop execution which allow you to instrument various points in the loop execution cycle. You can attach event handlers to any of these events:

events = [
  :started,             # After loop state initialization
  :epoch_started,       # On epoch start
  :iteration_started,   # On iteration start
  :iteration_completed, # On iteration complete
  :epoch_completed,     # On epoch complete
  :epoch_halted,        # On epoch halt, if early halted
  :halted,              # On loop halt, if early halted
  :completed            # On loop completion
]
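
To make the ordering concrete, here is a minimal sketch in plain Elixir that lists the events in the order the comments above describe, for a run that never halts early. EventOrderSketch is a hypothetical helper written for this illustration; it is not part of Axon and does not call into it, and it leaves out :epoch_halted and :halted, which only apply to early halts.

```elixir
defmodule EventOrderSketch do
  # Hypothetical helper (not an Axon API): returns the events a loop
  # fires, in order, assuming no handler halts the epoch or the loop.
  def fired_events(epochs, iterations) do
    # Each iteration is bracketed by a started/completed pair.
    iteration_events =
      for _ <- 1..iterations//1,
          event <- [:iteration_started, :iteration_completed],
          do: event

    # Each epoch wraps its iterations in its own started/completed pair.
    epoch_events = [:epoch_started] ++ iteration_events ++ [:epoch_completed]

    all_epochs = for _ <- 1..epochs//1, event <- epoch_events, do: event

    # The loop as a whole fires :started once and :completed once.
    [:started] ++ all_epochs ++ [:completed]
  end
end

# One epoch of two iterations.
IO.inspect(EventOrderSketch.fired_events(1, 2))
```

A handler attached to :iteration_completed therefore runs once per batch, while one attached to :epoch_completed runs once per pass over the data.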

Axon packages a number of common loop event handlers out of the box. These should cover most of the handlers you would otherwise need to write yourself. Axon also allows custom event handlers; see Writing custom event handlers for more information.

An event handler takes the current loop state at the time the event fires, and alters or uses it in some way before returning control to the main loop. You can attach any of Axon's pre-packaged event handlers to a loop by calling the corresponding function directly. For example, if you want to checkpoint loop state at the end of every epoch, you can use Axon.Loop.checkpoint/2:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(event: :epoch_completed)
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<17.37390314/1 in Axon.Loop.checkpoint/2>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>},
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Now when you execute your loop, it will save a checkpoint at the end of every epoch:

train_data =
  Stream.repeatedly(fn ->
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.5345965
Epoch: 1, Batch: 50, loss: 0.4578816
Epoch: 2, Batch: 50, loss: 0.4527244
Epoch: 3, Batch: 50, loss: 0.4466343
Epoch: 4, Batch: 50, loss: 0.4401709
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.1074252650141716, -0.0033432210329920053, -0.08044778555631638, 0.0016452680574730039, -0.01557128969579935, -0.061440952122211456, 0.061030879616737366, 0.012781506404280663]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.3504936695098877, 0.6722151041030884, -0.5550820231437683, 0.05254736915230751, 0.7404129505157471, -0.24307608604431152, -0.7073894739151001, 0.6447222828865051]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.19830459356307983, 0.0, 0.0, -0.04925372824072838]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.4873020648956299, -0.3363800644874573, -0.6058675050735474, -0.47888076305389404],
        [-0.18936580419540405, -0.5579301714897156, -0.49217337369918823, 0.04828363656997681],
        [0.3202762305736542, -0.033479928970336914, 0.11928367614746094, -0.5225698351860046],
        [0.3883931040763855, 0.07413274049758911, 0.548823893070221, -0.03494540974497795],
        [-0.2598196268081665, -0.4546756446361542, 0.5866180062294006, 0.2946240305900574],
        [0.2722054719924927, -0.5802338123321533, 0.4854300618171692, -0.5049118399620056],
        [-0.415179044008255, -0.5426293611526489, -0.1631108522415161, -0.6544353365898132],
        [-0.3079695403575897, 0.09391731023788452, -0.40262123942375183, -0.27837851643562317]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.016238097101449966]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.3102125823497772],
        [-1.078292727470398],
        [0.7910841703414917],
        [0.014510140754282475]
      ]
    >
  }
}
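
The handler contract described earlier (receive the loop state, use it, hand control back) can be sketched without running a loop. EpochReporter below is a hypothetical module written for this illustration, not something Axon ships. It assumes the loop state exposes an epoch field and that a handler signals "keep going" by returning a {:continue, state} tuple; a plain map stands in for the real loop state so the snippet runs on its own.

```elixir
defmodule EpochReporter do
  # Hypothetical custom handler (not shipped with Axon). It reads the
  # epoch from the state, prints it, and returns the state unchanged
  # so the loop can carry on.
  def report(%{epoch: epoch} = state) do
    IO.puts("finished epoch #{epoch}")
    {:continue, state}
  end
end

# A plain map stands in for the loop state in this sketch.
{:continue, returned} = EpochReporter.report(%{epoch: 3, metrics: %{}})
IO.inspect(returned.epoch)
```

With Axon loaded, a function like this could be attached with Axon.Loop.handle_event(loop, :epoch_completed, &EpochReporter.report/1). The guide on writing custom event handlers covers the details.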

You can also use event handlers for things as simple as custom logging, using the pre-packaged Axon.Loop.log/3 event handler:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.log(fn _state -> "epoch is over\n" end, event: :epoch_completed, device: :stdio)
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.3220241
epoch is over
Epoch: 1, Batch: 50, loss: 0.2309804
epoch is over
Epoch: 2, Batch: 50, loss: 0.1759415
epoch is over
Epoch: 3, Batch: 50, loss: 0.1457551
epoch is over
Epoch: 4, Batch: 50, loss: 0.1247821
epoch is over
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.01846296526491642, -0.0016654117498546839, 0.39859917759895325, 0.21187178790569305, 0.08815062046051025, -0.11071830987930298, 0.06280634552240372, -0.11682439595460892]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.08840499818325043, 0.44253841042518616, -0.6063749194145203, -0.1487167924642563, 0.24857401847839355, 0.1697462797164917, -0.5370600819587708, 0.1658734828233719]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.08111556619405746, 0.32310858368873596, -0.059386227279901505, -0.09515857696533203]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.6057762503623962, -0.2633209824562073, 0.23028653860092163, -0.2710704505443573],
        [0.03961030766367912, -0.335278183221817, 0.16016681492328644, 0.10653878003358841],
        [0.36239713430404663, 0.8330743312835693, 0.4745633602142334, -0.29585230350494385],
        [-0.04394621402025223, 0.45401355624198914, 0.5953336954116821, -0.6513576507568359],
        [-0.6447072625160217, -0.6225455403327942, -0.4814218580722809, 0.6882413625717163],
        [-0.44460421800613403, -0.04251839220523834, 0.4619944095611572, 0.24515877664089203],
        [-0.49396005272865295, -0.08895684778690338, 0.5212237238883972, 0.24301064014434814],
        [0.3074108958244324, 0.2640342712402344, 0.4197620749473572, -0.05698487162590027]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.6520459651947021]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.45083022117614746],
        [-0.8733288049697876],
        [-0.1894296556711197],
        [0.030911535024642944]
      ]
    >
  }
}

For even more fine-grained control over when event handlers fire, you can add filters. For example, if you only want to checkpoint loop state every 2 epochs, you can use a filter:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.checkpoint(event: :epoch_completed, filter: [every: 2])
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 50, loss: 0.3180207
Epoch: 1, Batch: 50, loss: 0.1975918
Epoch: 2, Batch: 50, loss: 0.1353940
Epoch: 3, Batch: 50, loss: 0.1055405
Epoch: 4, Batch: 50, loss: 0.0890203
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.047411054372787476, 0.1582564115524292, -0.027924394235014915, 0.1774083375930786, 0.09764095395803452, 0.1040089949965477, 0.006841400172561407, -0.11682236939668655]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.20366023480892181, 0.7318703532218933, -0.028611917048692703, -0.5324040055274963, -0.6856501698493958, 0.21694214642047882, 0.3281741738319397, -0.13051153719425201]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.1859581470489502, 0.3360026180744171, 0.24061667919158936, -0.016354668885469437]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.07366377860307693, -0.3261552155017853, -0.6951385140419006, -0.4232194125652313],
        [0.7334840893745422, -0.17827139794826508, -0.6411628127098083, -0.41898131370544434],
        [0.4770638346672058, -0.4738321304321289, 0.5755389332771301, 0.30976954102516174],
        [-0.498087614774704, 0.10546410828828812, 0.690037190914154, -0.5016340613365173],
        [0.17509347200393677, 0.4518563449382782, -0.10358063131570816, 0.2223401516675949],
        [0.6422480344772339, 0.19363932311534882, 0.2870054543018341, -0.1483648419380188],
        [-0.10362248122692108, -0.7047968506813049, 0.02847556211054325, -0.18464618921279907],
        [-0.6756409406661987, -0.42686882615089417, -0.5484509468078613, 0.596512496471405]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.23296000063419342]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.48827823996543884],
        [-0.7908728122711182],
        [-0.5326805114746094],
        [0.3789232671260834]
      ]
    >
  }
}

Axon event handlers support both keyword and function filters. Keyword filters use predicates such as :every, :once, and :always. Function filters are arity-1 functions that accept the current loop state and return a boolean.
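
As a rough mental model of how a filter gates a handler, the sketch below evaluates the filters named above in plain Elixir. FilterSketch is hypothetical and is not Axon's implementation. In particular, counting firings from 1 and letting every: n pass on multiples of n is an assumption made for this example, so the firings Axon actually selects may differ by version.

```elixir
defmodule FilterSketch do
  # Hypothetical model of filter evaluation (not Axon's code).
  # `count` is how many times the event has fired so far, starting
  # at 1; this counting convention is an assumption of the sketch.
  def fire?(:always, _count, _state), do: true
  def fire?(:once, count, _state), do: count == 1
  def fire?([every: n], count, _state), do: rem(count, n) == 0

  # A function filter receives the state and returns a boolean.
  def fire?(fun, _count, state) when is_function(fun, 1), do: fun.(state)
end

# Which of five firings pass an every: 2 filter under this model?
passing = Enum.filter(1..5, &FilterSketch.fire?([every: 2], &1, %{}))
IO.inspect(passing)
# => [2, 4]

# A function filter that only passes from epoch 3 onward.
late_only = fn %{epoch: epoch} -> epoch >= 3 end
IO.inspect(FilterSketch.fire?(late_only, 1, %{epoch: 4}))
# => true
```

Keyword filters are enough for periodic work such as checkpointing, while a function filter is useful when the decision depends on values in the loop state, for example a metric or the current epoch.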