Instrumenting loops with metrics
Mix.install([
  {:axon, github: "elixir-nx/axon"},
  {:nx, "~> 0.3.0", github: "elixir-nx/nx", sparse: "nx", override: true}
])
:ok
Adding metrics to training loops
Oftentimes when executing a loop, you want to keep track of various metrics such as accuracy or precision. By default, Axon's training loops only track loss; however, you can instrument the loop with additional built-in metrics. For example, you might want to track mean absolute error on top of a mean squared error loss:
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.metric(:mean_absolute_error)
#Axon.Loop<
  handlers: %{
    completed: [],
    epoch_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<5.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    epoch_halted: [],
    epoch_started: [],
    halted: [],
    iteration_completed: [
      {#Function<23.20267452/1 in Axon.Loop.log/5>,
       #Function<3.20267452/1 in Axon.Loop.build_filter_fn/1>}
    ],
    iteration_started: [],
    started: []
  },
  metrics: %{
    "loss" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     #Function<6.20267452/2 in Axon.Loop.build_loss_fn/1>},
    "mean_absolute_error" => {#Function<12.6031754/3 in Axon.Metrics.running_average/1>,
     :mean_absolute_error}
  },
  ...
>
When specifying a metric, you can pass an atom which maps to any of the metrics defined in Axon.Metrics. You can also define custom metrics; for more information, see Writing custom metrics.
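As a minimal sketch of the idea, a custom metric can be an arity-2 function of the targets and predictions that returns a scalar tensor. The custom_mae binding and the "manual mean absolute error" label below are illustrative names, not part of this guide:
# A sketch of a custom metric: an arity-2 function of targets and
# predictions returning a scalar tensor (here, a hand-rolled MAE).
custom_mae = fn y_true, y_pred ->
  y_true
  |> Nx.subtract(y_pred)
  |> Nx.abs()
  |> Nx.mean()
end

model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(custom_mae, "manual mean absolute error")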
When you run a loop with metrics, Axon will aggregate that metric over the course of the loop execution. For training loops, Axon will also report the aggregate metric in the training logs:
train_data =
  Stream.repeatedly(fn ->
    xs = Nx.random_normal({8, 1})
    ys = Nx.sin(xs)
    {xs, ys}
  end)
Axon.Loop.run(loop, train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0646209 mean_absolute_error: 0.1720028
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.2462722808122635, 0.18984302878379822, 0.0016971784643828869, 0.19568635523319244, 0.33571094274520874, 0.07703055441379547, 0.29576605558395386, 0.14511419832706451]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.7807592749595642, -0.17303702235221863, 0.43004679679870605, -0.46043306589126587, -0.6577866077423096, 0.7490359544754028, -0.5164405703544617, -0.77418452501297]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.027583779767155647, 0.4279942214488983, -0.10632428526878357, -0.05149337649345398]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[-0.5688502192497253, -0.49978527426719666, 0.0660838857293129, 0.30804139375686646],
[0.21578946709632874, 0.4183472990989685, 0.530754566192627, 0.1742597073316574],
[-0.17872463166713715, -0.08955764025449753, -0.7048909664154053, 0.053243234753608704],
[-0.41064000129699707, 0.3491946756839752, 0.3753710091114044, 0.6630277037620544],
[-0.1781950145959854, 0.5766432881355286, 0.5829672813415527, -0.34879636764526367],
[-0.026939965784549713, -0.44429031014442444, -0.12619371712207794, 0.0030224998481571674],
[0.411702424287796, 0.3330642879009247, -0.5062007308006287, -0.0731467455625534],
[-0.41474586725234985, 0.23881299793720245, 0.3847745358943939, -0.5769480466842651]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.8004998564720154]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[-0.40993982553482056],
[-1.0208697319030762],
[0.18116380274295807],
[-0.8320646286010742]
]
>
}
}
By default, the metric will have a name which matches the string form of the given metric. You can give metrics semantic meaning by providing an explicit name:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "model error")
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0559179 model error: 0.1430965
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[-0.2884136438369751, -0.016403740271925926, 0.30548375844955444, 0.2799474000930786, -0.017874717712402344, 0.3168976306915283, -0.10385002940893173, -0.18653006851673126]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.44000443816185, 0.6495574712753296, -0.5427255034446716, -0.795007050037384, -0.0035864184610545635, -0.5102121233940125, 0.10152970999479294, -0.3913733959197998]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[-0.24588409066200256, -0.05674195662140846, -0.08545850962400436, 0.27886852622032166]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.6334101557731628, -0.44550418853759766, 0.34385600686073303, 0.24886265397071838],
[-0.5474148988723755, 0.09881290793418884, 0.14616712927818298, 0.8087677359580994],
[-0.15381869673728943, 0.5322079658508301, -0.6275551915168762, -0.4207017421722412],
[0.4673740863800049, 0.5706797242164612, 0.44344833493232727, -0.5382705926895142],
[0.6662552356719971, -0.3875215947628021, -0.5359503626823425, -0.6198058724403381],
[-0.2842515707015991, 0.2379448264837265, 0.581102728843689, -0.5942302346229553],
[0.039275627583265305, 0.6341984272003174, -0.10589496046304703, -0.3522306978702545],
[0.4015151560306549, -0.15162920951843262, -0.3449919819831848, 0.21970798075199127]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[0.26691529154777527]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.7088357210159302],
[-0.9271859526634216],
[-0.1610293984413147],
[0.6011591553688049]
]
>
}
}
By default, Axon aggregates metrics with a running average; however, you can customize this behavior by specifying an explicit accumulation function. Built-in accumulation functions are :running_average and :running_sum:
model
|> Axon.Loop.trainer(:mean_squared_error, :sgd)
|> Axon.Loop.metric(:mean_absolute_error, "total error", :running_sum)
|> Axon.Loop.run(train_data, %{}, iterations: 1000)
Epoch: 0, Batch: 1000, loss: 0.0645265 total error: 158.5873566
%{
"dense_0" => %{
"bias" => #Nx.Tensor<
f32[8]
[0.013307658955454826, 0.08766761422157288, -0.0048030223697423935, -0.07024712860584259, 0.261692613363266, 0.0028863451443612576, -0.12552864849567413, 0.10552618652582169]
>,
"kernel" => #Nx.Tensor<
f32[1][8]
[
[-0.1647171825170517, -0.4144238233566284, -0.09969457238912582, -0.6063833832740784, 0.7182243466377258, -0.3485015034675598, -0.29005324840545654, -0.5282242298126221]
]
>
},
"dense_1" => %{
"bias" => #Nx.Tensor<
f32[4]
[0.021465059369802475, -0.16003911197185516, 0.6696521043777466, -0.15482725203037262]
>,
"kernel" => #Nx.Tensor<
f32[8][4]
[
[0.3359515964984894, -0.21561087667942047, -0.48400720953941345, -0.3186679184436798],
[-0.08509980887174606, -0.031951334327459335, -0.6084564924240112, -0.39506790041923523],
[0.003889488521963358, -0.12886928021907806, 0.5679722428321838, 0.22699925303459167],
[-0.315458744764328, 0.5626247525215149, -0.4241454303264618, -0.11212264746427536],
[0.6759291291236877, -0.6508319973945618, 0.3511318564414978, 0.17946019768714905],
[-0.7148906588554382, 0.45404312014579773, 0.4150676727294922, 0.33603984117507935],
[0.398037314414978, 0.5080180764198303, 0.6770725250244141, -0.5274750590324402],
[0.5072763562202454, -0.7351003289222717, -0.583225429058075, -0.2974703013896942]
]
>
},
"dense_2" => %{
"bias" => #Nx.Tensor<
f32[1]
[-0.8310347199440002]
>,
"kernel" => #Nx.Tensor<
f32[4][1]
[
[0.28011587262153625],
[0.542819082736969],
[1.2814348936080933],
[-0.5193246603012085]
]
>
}
}