Training and inference mode

Mix.install([
  {:axon, ">= 0.5.0"}
])
:ok

Executing models in inference mode

Some layers behave differently during model training than during model inference. For example, dropout layers are intended only for training, as a form of model regularization. Stateful layers such as batch normalization keep a running internal state that changes in training mode but stays fixed in inference mode. Axon supports mode-dependent execution through the :mode option, which is accepted by all building, compilation, and execution functions. By default, all models build in inference mode. You can see this by adding a dropout layer with a rate close to 1 (0.99 below). In inference mode this layer has no effect:

inputs = Nx.iota({2, 8}, type: :f32)

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  [
    [0.6900148391723633],
    [1.1159517765045166]
  ]
>

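You can also specify the mode explicitly. The numbers below differ from the previous output only because init_fn initializes a fresh set of random parameters. The dropout layer is still a no-op: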

{init_fn, predict_fn} = Axon.build(model, mode: :inference)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
#Nx.Tensor<
  f32[2][1]
  [
    [-1.1250841617584229],
    [-1.161189317703247]
  ]
>

It's important to know which mode your models were compiled for, because a model built in :inference mode behaves very differently from one built in :train mode.
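To make the difference concrete, the following sketch runs one set of parameters through Axon.predict/4, which builds and runs a model in a single call, in both modes. It assumes Axon 0.5.x (pinned in Mix.install so the API matches this guide) and that Axon.predict/4 forwards the :mode option to the build step, defaulting to :inference. In inference mode the result is a bare tensor. In train mode it is a map with :prediction and :state keys.

```elixir
# Assumption: Axon 0.5.x, matching the API used in this guide
Mix.install([{:axon, "~> 0.5.0"}])

inputs = Nx.iota({2, 8}, type: :f32)

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.dropout(rate: 0.99)
  |> Axon.dense(1)

# Initialize the parameters once and reuse them for both modes
{init_fn, _predict_fn} = Axon.build(model)
params = init_fn.(inputs, %{})

# Axon.predict/4 builds and runs the model in one call; :mode defaults to
# :inference, so the result is a plain tensor
inference_out = Axon.predict(model, params, inputs)

# The same call in :train mode returns a map with the prediction and updated state
%{prediction: train_out, state: train_state} =
  Axon.predict(model, params, inputs, mode: :train)

IO.inspect(inference_out, label: "inference")
IO.inspect(train_out, label: "train")
IO.inspect(Map.keys(train_state), label: "state keys")
```

Code that expects a tensor will break if it receives the train-mode map, which is one practical reason to keep track of the mode a model was built with.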

Executing models in training mode

By specifying mode: :train, you tell your models to execute in training mode. You can see the effects of this behavior here:

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
  prediction: #Nx.Tensor<
    f32[2][1]
    [
      [0.0],
      [0.0]
    ]
  >,
  state: %{
    "dropout_0" => %{
      "key" => #Nx.Tensor<
        u32[2]
        [309162766, 2699730300]
      >
    }
  }
}

First, notice that your model now returns a map with the keys :prediction and :state. :prediction contains the actual model prediction, while :state contains the updated state for any stateful layers, such as batch norm. In this example the dropout layer's random key is part of that state. The dropout layer also zeroed every activation, so the final dense layer returned only its bias, which Axon initializes to zeros by default. When writing custom training loops, you should extract :state and use it together with the updates API so that your stateful layers are updated correctly. If your model has stateful layers, :state will look similar to your model's parameter map:

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})
predict_fn.(params, inputs)
%{
  prediction: #Nx.Tensor<
    f32[2][1]
    [
      [0.4891311526298523],
      [-0.4891311228275299]
    ]
  >,
  state: %{
    "batch_norm_0" => %{
      "mean" => #Nx.Tensor<
        f32[4]
        [0.525083601474762, 0.8689039349555969, 0.03931800276041031, 0.0021854371298104525]
      >,
      "var" => #Nx.Tensor<
        f32[4]
        [0.13831248879432678, 0.10107331722974777, 0.10170891880989075, 0.10000484436750412]
      >
    }
  }
}
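The earlier advice to extract :state in a custom training loop can be illustrated with a manual approach, sketched below. The sketch assumes Axon 0.5.x, where init_fn returns a plain nested map (layer name => parameter name => tensor) that already contains each stateful layer's entries, such as "mean" and "var". After one training-mode forward pass, it merges the returned state into that map. The merge is an illustrative stand-in written for this example, not Axon's own update logic. In a full training loop the updates API mentioned above is the intended place for this step.

```elixir
# Assumption: Axon 0.5.x, where parameters are a plain nested map of
# layer name => %{parameter name => tensor}
Mix.install([{:axon, "~> 0.5.0"}])

inputs = Nx.iota({2, 8}, type: :f32)

model =
  Axon.input("data")
  |> Axon.dense(4)
  |> Axon.sigmoid()
  |> Axon.batch_norm()
  |> Axon.dense(1)

{init_fn, predict_fn} = Axon.build(model, mode: :train)
params = init_fn.(inputs, %{})

# One forward pass in training mode returns the prediction plus updated state
%{prediction: _preds, state: new_state} = predict_fn.(params, inputs)

# Illustrative merge (not an Axon API): for each layer that appears in
# new_state, overwrite its stateful entries (e.g. "mean" and "var") and keep
# every other parameter of that layer unchanged
params =
  Map.merge(params, new_state, fn _layer, layer_params, layer_state ->
    Map.merge(layer_params, layer_state)
  end)

IO.inspect(params["batch_norm_0"]["mean"], label: "running mean after one step")
```

Because the inner Map.merge/2 only overwrites keys present in new_state, trainable batch norm parameters such as "gamma" and "beta" keep their values. A real loop would also apply optimizer updates to them.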