View Source Writing custom metrics

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Writing custom metrics

When you pass an atom to Axon.Loop.metric/5, Axon dispatches to the built-in function of that name in Axon.Metrics. If you need a metric that does not exist in Axon.Metrics, you can define a custom function:

defmodule CustomMetric do
  import Nx.Defn

  @doc """
  Example custom metric: takes `Nx.atan2/2` of the targets and the predictions
  and sums the result.

  Receives the ground-truth tensor `y_true` and the prediction tensor `y_pred`
  and returns the summed value as a tensor.
  """
  defn my_weird_metric(y_true, y_pred) do
    angles = Nx.atan2(y_true, y_pred)
    Nx.sum(angles)
  end
end
{:module, CustomMetric, <<70, 79, 82, 49, 0, 0, 8, ...>>, true}

Then you can pass that directly to Axon.Loop.metric/5. You must provide a name for your custom metric:

# Hidden stack: a "data" input feeding two dense layers (8 and 4 units),
# each followed by a ReLU activation.
hidden =
  "data"
  |> Axon.input()
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

# Single-unit dense output layer on top of the hidden stack.
model = Axon.dense(hidden, 1)

# Supervised training loop using mean squared error loss and SGD.
trainer = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

# Attach the custom metric; a custom metric function needs an explicit name.
loop = Axon.Loop.metric(trainer, &CustomMetric.my_weird_metric/2, "my weird metric")
#Axon.Loop<
  metrics: %{
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>},
    "my weird metric" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     &CustomMetric.my_weird_metric/2}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Then, when the loop runs, Axon will invoke your custom metric function and accumulate its result with the metric's accumulator (a running average by default):

# Infinite stream of {inputs, targets} batches for fitting y = sin(x).
train_data =
  Stream.repeatedly(fn ->
    # Draw an integer seed in 1..9999 for this batch's PRNG key.
    # `:rand` is used instead of the deprecated `:random` module: `:random`
    # triggers a deprecation warning and starts from a fixed default seed in
    # an unseeded process, whereas `:rand` seeds itself automatically.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    # Targets are the sine of the sampled inputs.
    ys = Nx.sin(xs)
    {xs, ys}
  end)

# Run the loop from an empty initial state (`%{}`), capped at 1000 iterations
# because the data stream never ends on its own.
Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, loss: 0.0681635 my weird metric: -5.2842808
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.0866982489824295, 0.4234408140182495, 0.18205422163009644, 0.34029239416122437, -0.25770726799964905, -0.07117943465709686, 0.11470477283000946, -0.027526771649718285]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.7088809013366699, 0.4486531913280487, 0.4666421115398407, 0.4163222312927246, 0.5076444149017334, 0.10119977593421936, 0.6628422141075134, -0.024421442300081253]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.2924745976924896, 0.0065560233779251575, 0.0, -0.21106423437595367]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.3407173752784729, -0.6905813217163086, -0.5984221696853638, -0.23955762386322021],
        [0.42608022689819336, 0.5949274301528931, -0.24687853455543518, -0.4948572516441345],
        [0.27617380023002625, -0.44326621294021606, -0.5848686099052429, 0.31592807173728943],
        [0.5401414632797241, -0.1041281446814537, -0.4072037935256958, 0.4387882947921753],
        [-0.5410752892494202, 0.4544697403907776, -0.6238576173782349, -0.2077195793390274],
        [-0.41753143072128296, -0.11599045991897583, -0.22447934746742249, -0.5805748701095581],
        [0.1651047021150589, -0.526184618473053, 0.34729963541030884, 0.3307822048664093],
        [0.6879482865333557, 0.27184563875198364, -0.4907835125923157, -0.3555335998535156]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.8146252036094666]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [1.2187021970748901],
        [0.13001228868961334],
        [0.2703772783279419],
        [-0.3591017723083496]
      ]
    >
  }
}

While the metric defaults are designed with supervised training loops in mind, they can be used for much more flexible purposes. By default, metrics look for the fields :y_true and :y_pred in the given loop's step state. They then apply the given metric function on those inputs. You can also define metrics which work on other fields. For example, you can track the running average of a given parameter with a metric just by defining a custom output transform:

# Hidden stack: a "data" input feeding two dense layers (8 and 4 units),
# each followed by a ReLU activation.
hidden =
  "data"
  |> Axon.input()
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

# Single-unit dense output layer on top of the hidden stack.
model = Axon.dense(hidden, 1)

# Pull the first dense layer's kernel out of the step state. The result is a
# one-element list: the single argument passed to the arity-1 metric functions.
output_transform = fn %{model_state: params} ->
  [get_in(params, ["dense_0", "kernel"])]
end

# Supervised training loop using mean squared error loss and SGD.
trainer = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

# Track the running average of the kernel's mean and variance.
loop =
  trainer
  |> Axon.Loop.metric(&Nx.mean/1, "dense_0_kernel_mean", :running_average, output_transform)
  |> Axon.Loop.metric(&Nx.variance/1, "dense_0_kernel_var", :running_average, output_transform)
#Axon.Loop<
  metrics: %{
    "dense_0_kernel_mean" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     &Nx.mean/1},
    "dense_0_kernel_var" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     &Nx.variance/1},
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Axon will apply your custom output transform to the loop's step state and forward the result to your custom metric function:

# Infinite stream of {inputs, targets} batches for fitting y = sin(x).
train_data =
  Stream.repeatedly(fn ->
    # Draw an integer seed in 1..9999 for this batch's PRNG key.
    # `:rand` is used instead of the deprecated `:random` module: `:random`
    # triggers a deprecation warning and starts from a fixed default seed in
    # an unseeded process, whereas `:rand` seeds itself automatically.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    # Targets are the sine of the sampled inputs.
    ys = Nx.sin(xs)
    {xs, ys}
  end)

# Run the loop from an empty initial state (`%{}`), capped at 1000 iterations
# because the data stream never ends on its own.
Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_mean: -0.1978206 dense_0_kernel_var: 0.2699870 loss: 0.0605523
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [0.371105819940567, 0.26451945304870605, -0.048297226428985596, 0.14616385102272034, -0.19356133043766022, -0.2924956679344177, 0.08295489847660065, 0.25213995575904846]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [-0.3888320028781891, -0.39463144540786743, 0.5427617430686951, -0.776488721370697, -0.2402891218662262, -0.6489362716674805, 0.772796094417572, -0.3739306926727295]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.0, -0.006653765682131052, 0.0, 0.3086839020252228]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.5556576251983643, 0.5547546148300171, -0.2708005905151367, 0.7341570258140564],
        [-0.01800161600112915, 0.19749529659748077, -0.09523773193359375, 0.4989740252494812],
        [-0.19737857580184937, -0.2741832435131073, -0.3699955344200134, 0.21036939322948456],
        [-0.09787613153457642, -0.5631319284439087, 0.007957160472869873, 0.23681949079036713],
        [-0.469108909368515, 0.24062377214431763, -0.012939095497131348, -0.5055088400840759],
        [0.11229842901229858, -0.5476430058479309, 0.013744592666625977, -0.631401538848877],
        [-0.5834296941757202, -0.42305096983909607, 0.1393480896949768, -0.4647532105445862],
        [-0.3684111535549164, -0.5147689580917358, -0.3725535273551941, 0.46682292222976685]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [0.8305950164794922]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.7111979722976685],
        [-0.49341335892677307],
        [-0.32701319456100464],
        [-1.0638068914413452]
      ]
    >
  }
}

You can also define custom accumulation functions. Axon has definitions for computing running averages and running sums; however, you might find you need something like an exponential moving average:

defmodule CustomAccumulator do
  import Nx.Defn

  @doc """
  Exponential-moving-average accumulator.

  Blends the new observation `obs` into the running value `acc`, weighting the
  observation by `:alpha` and the previous value by `1 - :alpha`. The iteration
  index is accepted but unused.

  ## Options

    * `:alpha` - weight given to the new observation. Defaults to `0.9`.
  """
  defn running_ema(acc, obs, _i, opts \\ []) do
    opts = keyword!(opts, alpha: 0.9)
    alpha = opts[:alpha]
    decay = 1 - alpha

    obs * alpha + acc * decay
  end
end
{:module, CustomAccumulator, <<70, 79, 82, 49, 0, 0, 11, ...>>, true}

Your accumulator must be an arity-3 function which accepts the current accumulated value, the current observation, and the current iteration, and returns the aggregated metric. You can pass a function directly as an accumulator in your metric:

# Hidden stack: a "data" input feeding two dense layers (8 and 4 units),
# each followed by a ReLU activation.
hidden =
  "data"
  |> Axon.input()
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()

# Single-unit dense output layer on top of the hidden stack.
model = Axon.dense(hidden, 1)

# Pull the first dense layer's kernel out of the step state. The result is a
# one-element list: the single argument passed to the arity-1 metric function.
output_transform = fn %{model_state: params} ->
  [get_in(params, ["dense_0", "kernel"])]
end

# Supervised training loop using mean squared error loss and SGD.
trainer = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

# Track the kernel's mean, accumulated with the custom accumulator function
# instead of a built-in accumulator atom.
loop =
  Axon.Loop.metric(
    trainer,
    &Nx.mean/1,
    "dense_0_kernel_ema_mean",
    &CustomAccumulator.running_ema/3,
    output_transform
  )
#Axon.Loop<
  metrics: %{
    "dense_0_kernel_ema_mean" => {#Function<15.37390314/3 in Axon.Loop.build_metric_fn/3>,
     &Nx.mean/1},
    "loss" => {#Function<11.133813849/3 in Axon.Metrics.running_average/1>,
     #Function<9.37390314/2 in Axon.Loop.build_loss_fn/1>}
  },
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<6.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<27.37390314/1 in Axon.Loop.log/3>,
       #Function<64.37390314/2 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  ...
>

Then when you run the loop, Axon will use your custom accumulator:

# Infinite stream of {inputs, targets} batches for fitting y = sin(x).
train_data =
  Stream.repeatedly(fn ->
    # Draw an integer seed in 1..9999 for this batch's PRNG key.
    # `:rand` is used instead of the deprecated `:random` module: `:random`
    # triggers a deprecation warning and starts from a fixed default seed in
    # an unseeded process, whereas `:rand` seeds itself automatically.
    {xs, _next_key} =
      :rand.uniform(9999)
      |> Nx.Random.key()
      |> Nx.Random.normal(shape: {8, 1})

    # Targets are the sine of the sampled inputs.
    ys = Nx.sin(xs)
    {xs, ys}
  end)

# Run the loop from an empty initial state (`%{}`), capped at 1000 iterations
# because the data stream never ends on its own.
Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 950, dense_0_kernel_ema_mean: -0.0139760 loss: 0.0682910
%{
  "dense_0" => %{
    "bias" => #Nx.Tensor<
      f32[8]
      [-0.3344854414463043, -0.14519920945167542, 0.1061621680855751, 0.36911827325820923, 0.014146199449896812, 0.46089673042297363, -0.1707312911748886, -0.054649338126182556]
    >,
    "kernel" => #Nx.Tensor<
      f32[1][8]
      [
        [0.6524605751037598, -0.3795280158519745, -0.2069108486175537, 0.6815686821937561, -0.5734748840332031, 0.5515486001968384, -0.13509605824947357, -0.711794912815094]
      ]
    >
  },
  "dense_1" => %{
    "bias" => #Nx.Tensor<
      f32[4]
      [0.3078235387802124, -0.24773009121418, -0.027328377589583397, 0.0769796073436737]
    >,
    "kernel" => #Nx.Tensor<
      f32[8][4]
      [
        [-0.785156786441803, 0.07306647300720215, 0.339533269405365, -0.2188076674938202],
        [0.29139244556427, 0.15977036952972412, 0.6193944215774536, -0.4305708408355713],
        [-0.21063144505023956, -0.3738138973712921, -0.27965712547302246, 0.051842525601387024],
        [0.7297297716140747, -0.08164620399475098, 0.07651054859161377, -0.43577027320861816],
        [0.07917583733797073, -0.27750709652900696, 0.21028375625610352, -0.6430750489234924],
        [0.7177602648735046, -0.2743614912033081, -0.5894488096237183, 0.634209156036377],
        [0.4251592457294464, 0.6134526133537292, -0.35339266061782837, 0.4966743588447571],
        [-0.49672019481658936, 0.46769094467163086, -0.44432300329208374, -0.3249942660331726]
      ]
    >
  },
  "dense_2" => %{
    "bias" => #Nx.Tensor<
      f32[1]
      [-0.8245151042938232]
    >,
    "kernel" => #Nx.Tensor<
      f32[4][1]
      [
        [0.9500011205673218],
        [0.9115968942642212],
        [0.39282673597335815],
        [0.19936752319335938]
      ]
    >
  }
}