Using loop event handlers

Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
:ok

Adding event handlers to training loops

You often want finer-grained control over what happens during loop execution. For example, you might want to save loop state to a file every 500 iterations, or log some output to :stdio at the end of every epoch. Axon loops provide this control through events and event handlers.

Axon fires a number of events during loop execution, which let you instrument various points in the execution cycle. You can attach event handlers to any of these events:

events = [
  :started,             # After loop state initialization
  :epoch_started,       # On epoch start
  :iteration_started,   # On iteration start
  :iteration_completed, # On iteration complete
  :epoch_completed,     # On epoch complete
  :epoch_halted,        # On epoch halt, if early halted
  :halted,              # On loop halt, if early halted
  :completed            # On loop completion
]
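To make the firing order implied by the comments above concrete, here is a minimal plain-Elixir sketch of a loop that fires events and threads state through handlers. It is a toy model, not Axon's implementation: the `ToyLoop` module, the `record` helper, and the shape of the state map are all hypothetical, and the halt events are left out for brevity.

```elixir
defmodule ToyLoop do
  # Hypothetical stand-in for a loop runner; NOT Axon's implementation.
  # `handlers` maps an event name to a list of arity-1 functions that take
  # the current state and return the (possibly updated) state.
  def run(handlers, state, epochs, iterations) do
    state = fire(handlers, :started, state)

    state =
      Enum.reduce(0..(epochs - 1)//1, state, fn epoch, state ->
        state = fire(handlers, :epoch_started, %{state | epoch: epoch})

        state =
          Enum.reduce(0..(iterations - 1)//1, state, fn iteration, state ->
            state = fire(handlers, :iteration_started, %{state | iteration: iteration})
            # A real training step would update the model state here.
            fire(handlers, :iteration_completed, state)
          end)

        fire(handlers, :epoch_completed, state)
      end)

    fire(handlers, :completed, state)
  end

  # Run every handler registered for `event`, passing the state along.
  defp fire(handlers, event, state) do
    handlers
    |> Map.get(event, [])
    |> Enum.reduce(state, fn handler, acc -> handler.(acc) end)
  end
end

# Hypothetical handler factory: records the event name in the state's log.
record = fn event -> fn state -> %{state | log: [event | state.log]} end end

handlers = %{
  started: [record.(:started)],
  epoch_started: [record.(:epoch_started)],
  epoch_completed: [record.(:epoch_completed)],
  completed: [record.(:completed)]
}

final = ToyLoop.run(handlers, %{epoch: 0, iteration: 0, log: []}, 2, 2)
fired = Enum.reverse(final.log)

IO.inspect(fired)
# [:started, :epoch_started, :epoch_completed, :epoch_started, :epoch_completed, :completed]
```

Each handler receives the state, does its work, and hands the state back, which is why a handler can both observe and modify what the loop sees next.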

Axon ships with a number of common loop event handlers out of the box. These should cover most of the handlers you would otherwise need to write in practice. Axon also supports custom event handlers. See Writing custom event handlers for more information.

An event handler takes the current loop state at the time the event fires, and alters or uses it in some way before returning control to the main loop. You can attach any of Axon's pre-packaged event handlers to a loop by calling the corresponding function directly. For example, if you want to checkpoint loop state at the end of every epoch, you can use Axon.Loop.checkpoint/2:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(event: :epoch_completed)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<14.20267452/1 in Axon.Loop.checkpoint/2>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>},
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>}
  },
  ...
>

Now when you execute your loop, it will save a checkpoint at the end of every epoch:

train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 100, loss: 0.2462310
Epoch: 1, Batch: 100, loss: 0.1804814
Epoch: 2, Batch: 100, loss: 0.1452925
Epoch: 3, Batch: 100, loss: 0.1177117
Epoch: 4, Batch: 100, loss: 0.1008184
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.36853691935539246, 0.24528849124908447, 0.13193830847740173, 0.03188902884721756, -0.06358373910188675, 0.044517479836940765, -0.1203451156616211, -6.352089694701135e-4]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.49448737502098083, 0.5250089764595032, 0.7132464051246643, 0.47473379969596863, -0.043285828083753586, -0.14137212932109833, -0.07576408237218857, -0.48898136615753174]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.30324652791023254, 0.0385407879948616, -0.16782516241073608, 0.1984063982963562]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.2536502778530121, 0.375381737947464, 0.7119463086128235, -0.14521682262420654],
        [0.20504063367843628, -0.11605211347341537, 0.49423739314079285, -0.03246872499585152],
        [-0.13834621012210846, -0.2579476833343506, 0.34836748242378235, -0.4670639634132385],
        [-0.11925031989812851, -0.6655324697494507, 0.5057039856910706, 0.496115118265152],
        [0.15856991708278656, -0.2239169478416443, 0.5550385117530823, -0.3774339258670807],
        [-0.326529860496521, -0.10192928463220596, 0.2961374819278717, 0.580808699131012],
        [0.46179524064064026, -0.4794206917285919, 0.47078272700309753, -0.5654175877571106],
        [-0.501025915145874, -0.38049301505088806, 0.3792027235031128, 0.685397207736969]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.4034360647201538]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.8062413334846497],
        [0.6867087483406067],
        [0.5137255787849426],
        [-0.5783006548881531]
      ]
    >
  }
}

You can also use event handlers for things as simple as implementing custom logging with the pre-packaged Axon.Loop.log/4 event handler:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.log(:epoch_completed, fn _state -> "epoch is over\n" end, :stdio)
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 100, loss: 0.2134880
epoch is over
Epoch: 1, Batch: 100, loss: 0.1604774
epoch is over
Epoch: 2, Batch: 100, loss: 0.1294429
epoch is over
Epoch: 3, Batch: 100, loss: 0.1087099
epoch is over
Epoch: 4, Batch: 100, loss: 0.0940388
epoch is over
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.1741544008255005, -0.013307991437613964, 0.0873112753033638, -0.04722493514418602, -0.12966567277908325, 0.04596322402358055, 0.3969370722770691, -0.04508184269070625]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.31960299611091614, -0.5328841805458069, -0.24278149008750916, -0.47772416472435, 0.21538947522640228, -0.2799384295940399, 0.5947694778442383, 0.0497460775077343]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.25857725739479065, -0.07283111661672592, -0.10656370222568512, -0.08234459906816483]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.3983175754547119, -0.5524351596832275, 0.36650899052619934, -0.23933114111423492],
        [0.06517457216978073, 0.2564122974872589, 0.6227137446403503, -0.5661884546279907],
        [-0.7012182474136353, 0.054501600563526154, -0.6726318597793579, 0.4774037301540375],
        [-0.11393500864505768, 0.1726256012916565, -0.6723376512527466, 0.6044175028800964],
        [-0.30502673983573914, 0.7011693120002747, 0.40034061670303345, -0.5748327374458313],
        [-0.07724377512931824, -0.251364529132843, -0.6626797914505005, -0.20940908789634705],
        [0.7290927767753601, 0.08563250303268433, -0.047927819192409515, -0.04336162284016609],
        [-0.34993213415145874, 0.281339168548584, -0.49343380331993103, -0.2481663078069687]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.6856028437614441]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.1966136693954468],
        [-0.00546963419765234],
        [-0.9349364042282104],
        [0.9214714765548706]
      ]
    >
  }
}

For even more fine-grained control over when event handlers fire, you can add filters. For example, if you only want to checkpoint loop state every 2 epochs, you can use a filter:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.checkpoint(event: :epoch_completed, filter: [every: 2])
|> Axon.Loop.run(train_data, %{}, epochs: 5, iterations: 100)
Epoch: 0, Batch: 100, loss: 0.1791917
Epoch: 1, Batch: 100, loss: 0.1373887
Epoch: 2, Batch: 100, loss: 0.1156979
Epoch: 3, Batch: 100, loss: 0.0965481
Epoch: 4, Batch: 100, loss: 0.0865761
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.00938357226550579, 0.16315333545207977, 0.2767408788204193, -0.22733710706233978, 0.2830233573913574, -0.10280115902423859, -0.07500249892473221, 0.2947545647621155]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.522411048412323, 0.15686289966106415, 0.30727216601371765, 0.3295647203922272, 0.38795727491378784, 0.17159366607666016, 0.7608513236045837, 0.4526905119419098]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.024011338129639626, 0.0, -0.00135718728415668, -0.0015321056125685573]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.606391966342926, -0.08385708928108215, 0.06838012486696243, -0.08704598248004913],
        [0.5944894552230835, -0.17639528214931488, 0.26653605699539185, 0.35148826241493225],
        [-0.06138936057686806, -0.024123376235365868, 0.29706713557243347, 0.5498997569084167],
        [0.26888611912727356, 0.024979088455438614, -0.653775155544281, -0.4111217260360718],
        [-0.5042538046836853, -0.6867390871047974, 0.13647332787513733, 0.7193269729614258],
        [-0.052732646465301514, 0.099549300968647, -0.6970457434654236, 0.3078557252883911],
        [-0.261769562959671, 0.17121906578540802, -0.08267408609390259, -0.2213396430015564],
        [-0.09766292572021484, -0.5843542218208313, 0.369784414768219, 0.48434120416641235]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.6914201378822327]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.96906977891922],
        [-0.5032458901405334],
        [0.9275273680686951],
        [0.8574270606040955]
      ]
    >
  }
}

Axon event handlers support both keyword and function filters. Keyword filters include :every (as in filter: [every: 2] above), :once, and :always. Function filters are arity-1 functions that accept the current loop state and return a boolean.
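As a rough illustration of how the two kinds of filter relate, the plain-Elixir sketch below turns each filter spec into an arity-1 predicate over a state map. It is only a model under stated assumptions: `ToyFilter`, the `times_fired` counter, and the `loss` field are hypothetical, and exactly which occurrences `[every: n]` matches in Axon (for example, whether the first one passes) is assumed here rather than taken from the library.

```elixir
defmodule ToyFilter do
  # Hypothetical helper sketching filter semantics; NOT Axon's code.
  # Every filter spec becomes a function from state to boolean.
  def build(:always), do: fn _state -> true end

  # Assumption: the state tracks how often the event has already fired.
  def build(:once), do: fn state -> state.times_fired == 0 end

  def build([every: n]) when is_integer(n) and n > 0 do
    fn state -> rem(state.times_fired, n) == 0 end
  end

  # A function filter is used as-is: it already maps state to a boolean.
  def build(fun) when is_function(fun, 1), do: fun
end

# Five made-up end-of-epoch states with a steadily shrinking loss.
states = for i <- 0..4, do: %{epoch: i, times_fired: i, loss: 1.0 / (i + 1)}

# Return the epochs at which a handler guarded by `filter` would run.
passes = fn filter ->
  pred = ToyFilter.build(filter)
  for state <- states, pred.(state), do: state.epoch
end

always = passes.(:always)
once = passes.(:once)
every_two = passes.([every: 2])
low_loss = passes.(fn state -> state.loss < 0.3 end)

IO.inspect({always, once, every_two, low_loss})
# {[0, 1, 2, 3, 4], [0], [0, 2, 4], [3, 4]}
```

A function filter is the most flexible option because it can inspect anything in the state, such as a metric value, rather than only how often the event has fired.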