Soothsayer.Trainer (Soothsayer v0.6.1)
Training functionality for Soothsayer models.
Supports both standard training and custom training that adds an L1 penalty on the weights of specified layers (for example, the AR and trend layers).
Functions
@spec compute_l1_penalty(Axon.ModelState.t(any(), any()), [String.t()]) :: Nx.Tensor.t()
Computes the L1 penalty over the kernels of the specified layers.
Parameters
- params - An Axon.ModelState containing the model parameters.
- layer_names - A list of layer names to include in the penalty.
Returns
A scalar tensor with the sum of absolute values of all kernel weights in the specified layers.
Examples
iex> penalty = Soothsayer.Trainer.compute_l1_penalty(params, ["ar_dense_out"])
#Nx.Tensor<f32 6.0>
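To make the computation concrete, here is a minimal pure-Elixir sketch. It is not Soothsayer's implementation. It assumes parameters are a plain nested map of layer name to parameter name to weights (loosely mirroring how Axon groups parameters by layer), represents kernels as nested lists rather than Nx tensors, and, as a choice of this sketch only, treats a missing layer as contributing zero. The module name L1PenaltySketch and the layer name "trend_dense" are hypothetical.

```elixir
defmodule L1PenaltySketch do
  # Hypothetical sketch: params is %{layer_name => %{"kernel" => nested_list, ...}}.
  # Only "kernel" entries are penalized; biases and other parameters are ignored.
  def compute_l1_penalty(params, layer_names) do
    layer_names
    |> Enum.map(fn name ->
      case params do
        %{^name => %{"kernel" => kernel}} -> abs_sum(kernel)
        # Assumption of this sketch: a missing layer or kernel contributes 0.
        _ -> 0.0
      end
    end)
    |> Enum.sum()
  end

  # Sum of absolute values over an arbitrarily nested list of numbers.
  defp abs_sum(values) when is_list(values), do: values |> Enum.map(&abs_sum/1) |> Enum.sum()
  defp abs_sum(value) when is_number(value), do: abs(value) * 1.0
end

params = %{
  "ar_dense_out" => %{"kernel" => [[1.0, -2.0], [3.0, 0.0]], "bias" => [5.0, 5.0]},
  "trend_dense" => %{"kernel" => [[-4.0]], "bias" => [1.0]}
}

L1PenaltySketch.compute_l1_penalty(params, ["ar_dense_out"])
# => 6.0 (|1| + |-2| + |3| + |0|; the bias is not included)
```

With real Nx tensors, the per-layer term would typically be written as Nx.sum(Nx.abs(kernel)), which keeps the result a scalar tensor.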
@spec fit(Axon.t(), %{required(String.t()) => Nx.Tensor.t()}, Nx.Tensor.t(), non_neg_integer(), map()) :: Axon.ModelState.t(any(), any())
Trains a network on the provided data.
Parameters
- network - An Axon neural network.
- x - A map of input tensors, keyed by input name.
- y - A tensor of target values.
- epochs - The number of training epochs.
- config - A map containing the training configuration:
  - :learning_rate - Learning rate for the optimizer.
  - :ar - Optional map whose :regularization value sets the strength of the L1 penalty on the AR layer weights.
  - :trend - Optional map whose :regularization value sets the strength of the L1 penalty on the trend layer weights.
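The :ar and :trend options suggest an objective of the form "base loss plus a weighted L1 term per penalized layer group". The sketch below shows one way such a config could be folded into the loss. It is an assumption, not Soothsayer's code: the base loss, the layers each key covers (only "ar_dense_out" appears on this page; "trend_dense" is invented), and how fit actually combines the terms are all hypothetical, and penalty_fun is a stub standing in for an L1 penalty function.

```elixir
defmodule RegularizedLossSketch do
  # Hypothetical mapping from config keys to the layers each penalty covers.
  # "ar_dense_out" appears in the compute_l1_penalty example; "trend_dense" is invented.
  @penalized_layers %{ar: ["ar_dense_out"], trend: ["trend_dense"]}

  # Adds lambda * penalty for every config key (:ar, :trend) that carries a
  # numeric :regularization value; absent keys leave the loss unchanged.
  def total_loss(base_loss, config, penalty_fun) do
    Enum.reduce(@penalized_layers, base_loss, fn {key, layers}, acc ->
      case config do
        %{^key => %{regularization: lambda}} when is_number(lambda) ->
          acc + lambda * penalty_fun.(layers)

        _ ->
          acc
      end
    end)
  end
end

# Stub penalty function standing in for an L1 penalty over the given layers.
penalty_fun = fn
  ["ar_dense_out"] -> 6.0
  ["trend_dense"] -> 4.0
end

config = %{learning_rate: 0.1, ar: %{regularization: 0.5}}
RegularizedLossSketch.total_loss(1.0, config, penalty_fun)
# => 4.0 (1.0 + 0.5 * 6.0; no :trend entry, so no trend penalty)
```

Keeping the penalized layer groups in one map makes the optional keys uniform: each group is penalized only when its config entry is present.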
Returns
The trained Axon.ModelState.
Examples
iex> config = %{learning_rate: 0.1}
iex> params = Soothsayer.Trainer.fit(network, x, y, 100, config)
%Axon.ModelState{...}
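A configuration that enables both penalties could look like the following. The keys are the ones documented above; the numeric values are illustrative only. Because network, x and y are assumed to be defined elsewhere, the fit call is left as a comment.

```elixir
# Illustrative values only; tune for your data.
config = %{
  learning_rate: 0.01,
  # L1 penalty strength for the AR layer weights
  ar: %{regularization: 0.1},
  # L1 penalty strength for the trend layer weights
  trend: %{regularization: 0.01}
}

# Assumes network, x and y are already defined:
# params = Soothsayer.Trainer.fit(network, x, y, 100, config)
```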