Instrumenting loops with metrics

Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
:ok


Adding metrics to training loops

Oftentimes when executing a loop you want to keep track of various metrics such as accuracy or precision. For training loops, Axon by default only tracks loss; however, you can instrument the loop with additional built-in metrics. For example, you might want to track mean absolute error on top of a mean squared error loss:

model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(:mean_absolute_error)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>},
    "mean_absolute_error" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     :mean_absolute_error}
  },
  ...
>

When specifying a metric, you can pass an atom which maps to any of the metrics defined in Axon.Metrics. You can also define custom metrics. For more information on custom metrics, see Writing custom metrics.
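As a quick sketch of the custom-metric pattern: Axon.Loop.metric also accepts a function of the targets and predictions in place of an atom. The metric below, within_tolerance, is a hypothetical example (not part of Axon) that measures the fraction of predictions within 0.1 of the target:

defmodule CustomMetric do
  import Nx.Defn

  # Hypothetical metric: fraction of predictions within 0.1 of the target.
  # Like the built-in metrics, it receives targets and predictions.
  defn within_tolerance(y_true, y_pred) do
    Nx.abs(y_true - y_pred)
    |> Nx.less(0.1)
    |> Nx.mean()
  end
end

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(&CustomMetric.within_tolerance/2, "within tolerance")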

When you run a loop with metrics, Axon will aggregate each metric over the course of the loop's execution. For training loops, Axon will also report the aggregate metric in the training logs:

train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)

Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0646209 mean_absolute_error: 0.1720028
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.2462722808122635, 0.18984302878379822, 0.0016971784643828869, 0.19568635523319244, 0.33571094274520874, 0.07703055441379547, 0.29576605558395386, 0.14511419832706451]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.7807592749595642, -0.17303702235221863, 0.43004679679870605, -0.46043306589126587, -0.6577866077423096, 0.7490359544754028, -0.5164405703544617, -0.77418452501297]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.027583779767155647, 0.4279942214488983, -0.10632428526878357, -0.05149337649345398]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.5688502192497253, -0.49978527426719666, 0.0660838857293129, 0.30804139375686646],
        [0.21578946709632874, 0.4183472990989685, 0.530754566192627, 0.1742597073316574],
        [-0.17872463166713715, -0.08955764025449753, -0.7048909664154053, 0.053243234753608704],
        [-0.41064000129699707, 0.3491946756839752, 0.3753710091114044, 0.6630277037620544],
        [-0.1781950145959854, 0.5766432881355286, 0.5829672813415527, -0.34879636764526367],
        [-0.026939965784549713, -0.44429031014442444, -0.12619371712207794, 0.0030224998481571674],
        [0.411702424287796, 0.3330642879009247, -0.5062007308006287, -0.0731467455625534],
        [-0.41474586725234985, 0.23881299793720245, 0.3847745358943939, -0.5769480466842651]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.8004998564720154]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [-0.40993982553482056],
        [-1.0208697319030762],
        [0.18116380274295807],
        [-0.8320646286010742]
      ]
    >
  }
}

By default, the metric will have a name which matches the string form of the given metric. You can give metrics semantic meaning by providing an explicit name:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "model error")
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0559179 model error: 0.1430965
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.2884136438369751, -0.016403740271925926, 0.30548375844955444, 0.2799474000930786, -0.017874717712402344, 0.3168976306915283, -0.10385002940893173, -0.18653006851673126]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.44000443816185, 0.6495574712753296, -0.5427255034446716, -0.795007050037384, -0.0035864184610545635, -0.5102121233940125, 0.10152970999479294, -0.3913733959197998]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [-0.24588409066200256, -0.05674195662140846, -0.08545850962400436, 0.27886852622032166]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.6334101557731628, -0.44550418853759766, 0.34385600686073303, 0.24886265397071838],
        [-0.5474148988723755, 0.09881290793418884, 0.14616712927818298, 0.8087677359580994],
        [-0.15381869673728943, 0.5322079658508301, -0.6275551915168762, -0.4207017421722412],
        [0.4673740863800049, 0.5706797242164612, 0.44344833493232727, -0.5382705926895142],
        [0.6662552356719971, -0.3875215947628021, -0.5359503626823425, -0.6198058724403381],
        [-0.2842515707015991, 0.2379448264837265, 0.581102728843689, -0.5942302346229553],
        [0.039275627583265305, 0.6341984272003174, -0.10589496046304703, -0.3522306978702545],
        [0.4015151560306549, -0.15162920951843262, -0.3449919819831848, 0.21970798075199127]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.26691529154777527]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.7088357210159302],
        [-0.9271859526634216],
        [-0.1610293984413147],
        [0.6011591553688049]
      ]
    >
  }
}

By default, Axon aggregates metrics with a running average; however, you can customize this behavior by specifying an explicit accumulation function. The built-in accumulation functions are :running_average and :running_sum:

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "total error", :running_sum)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0645265 total error: 158.5873566
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.013307658955454826, 0.08766761422157288, -0.0048030223697423935, -0.07024712860584259, 0.261692613363266, 0.0028863451443612576, -0.12552864849567413, 0.10552618652582169]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.1647171825170517, -0.4144238233566284, -0.09969457238912582, -0.6063833832740784, 0.7182243466377258, -0.3485015034675598, -0.29005324840545654, -0.5282242298126221]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.021465059369802475, -0.16003911197185516, 0.6696521043777466, -0.15482725203037262]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [0.3359515964984894, -0.21561087667942047, -0.48400720953941345, -0.3186679184436798],
        [-0.08509980887174606, -0.031951334327459335, -0.6084564924240112, -0.39506790041923523],
        [0.003889488521963358, -0.12886928021907806, 0.5679722428321838, 0.22699925303459167],
        [-0.315458744764328, 0.5626247525215149, -0.4241454303264618, -0.11212264746427536],
        [0.6759291291236877, -0.6508319973945618, 0.3511318564414978, 0.17946019768714905],
        [-0.7148906588554382, 0.45404312014579773, 0.4150676727294922, 0.33603984117507935],
        [0.398037314414978, 0.5080180764198303, 0.6770725250244141, -0.5274750590324402],
        [0.5072763562202454, -0.7351003289222717, -0.583225429058075, -0.2974703013896942]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.8310347199440002]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.28011587262153625],
        [0.542819082736969],
        [1.2814348936080933],
        [-0.5193246603012085]
      ]
    >
  }
}
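You can also supply your own accumulation function in place of a built-in atom. A minimal sketch follows, assuming the accumulator is called with the accumulated value, the current observation, and the iteration index, mirroring the arity-3 shape of the built-in running average shown in the loop's metrics map above (check the Axon.Loop.metric docs for your version). The exponentially-weighted moving average here is a hypothetical example:

# Hypothetical accumulator: exponentially-weighted moving average of the
# metric, assuming the (accumulated, observation, iteration) calling
# convention used by the built-in :running_average.
ewma = fn acc, obs, _i ->
  Nx.add(Nx.multiply(acc, 0.9), Nx.multiply(obs, 0.1))
end

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "ewma error", ewma)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)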