Soothsayer.Trainer (Soothsayer v0.6.1)


Training functionality for Soothsayer models.

Supports standard training as well as custom training that adds an L1 penalty on the weights of specified layers (for example, the AR and trend layers).
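As a rough guide, the regularized objective presumably has the form below. Here $\lambda_{\text{AR}}$ and $\lambda_{\text{trend}}$ stand for the `:regularization` values in the `:ar` and `:trend` config maps, and $W_{\text{AR}}$, $W_{\text{trend}}$ are the kernel weights of those layers. The data term $\mathcal{L}_{\text{data}}$ is an assumption: this page does not say which base loss `fit/5` uses.

$$
\mathcal{L}(\theta) = \mathcal{L}_{\text{data}}(\theta) + \lambda_{\text{AR}} \sum_{w \in W_{\text{AR}}} |w| + \lambda_{\text{trend}} \sum_{w \in W_{\text{trend}}} |w|
$$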

Summary

Functions

compute_l1_penalty(params, layer_names)
Computes the L1 penalty for the kernels of the specified layers.

fit(network, x, y, epochs, config)
Trains a network on the provided data.

Functions

compute_l1_penalty(params, layer_names)

@spec compute_l1_penalty(Axon.ModelState.t(any(), any()), [String.t()]) ::
  Nx.Tensor.t()

Computes the L1 penalty for the kernels of the specified layers.

Parameters

  • params - An Axon.ModelState containing the model parameters.
  • layer_names - A list of layer names to include in the penalty.

Returns

A scalar tensor with the sum of absolute values of all kernel weights in the specified layers.

Examples

iex> penalty = Soothsayer.Trainer.compute_l1_penalty(params, ["ar_dense_out"])
#Nx.Tensor<
  f32
  6.0
>
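To make the computation concrete, here is a dependency-free sketch of what the penalty sums; it is not Soothsayer's implementation. It assumes parameters are organized as a map from layer name to a map with `"kernel"` and `"bias"` entries (similar to the data Axon keeps inside its model state) and uses plain nested lists in place of Nx tensors. The module name `L1Sketch` and the layer name `"trend_dense"` are hypothetical.

```elixir
defmodule L1Sketch do
  # Hypothetical sketch, not Soothsayer's implementation.
  # params: %{layer_name => %{"kernel" => nested_list, "bias" => list}}
  # Plain nested lists stand in for Nx tensors.

  # Sums |w| over the "kernel" of each named layer; "bias" entries are ignored.
  # Raises KeyError if a layer name or its "kernel" is missing (an assumption).
  def compute_l1_penalty(params, layer_names) do
    layer_names
    |> Enum.map(fn name -> params |> Map.fetch!(name) |> Map.fetch!("kernel") end)
    |> Enum.map(&abs_sum/1)
    |> Enum.sum()
  end

  # Recursively sums absolute values in a number or a nested list of numbers.
  defp abs_sum(values) when is_list(values), do: values |> Enum.map(&abs_sum/1) |> Enum.sum()
  defp abs_sum(value) when is_number(value), do: abs(value)
end

params = %{
  "ar_dense_out" => %{"kernel" => [[1.0, -2.0], [3.0, 0.0]], "bias" => [5.0, 5.0]},
  "trend_dense" => %{"kernel" => [[0.5]], "bias" => [1.0]}
}

IO.inspect(L1Sketch.compute_l1_penalty(params, ["ar_dense_out"]))
# prints 6.0 (the bias values are not counted)
```

Only kernels are summed, matching the "layer kernels" wording above; whether the real function raises or skips on an unknown layer name is not documented here.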

fit(network, x, y, epochs, config)

@spec fit(
  Axon.t(),
  %{required(String.t()) => Nx.Tensor.t()},
  Nx.Tensor.t(),
  non_neg_integer(),
  map()
) ::
  Axon.ModelState.t(any(), any())

Trains a network on the provided data.

Parameters

  • network - An Axon neural network.
  • x - A map from input names (strings) to input tensors.
  • y - A tensor of target values.
  • epochs - The number of training epochs.
  • config - A map containing training configuration:
    • :learning_rate - Learning rate for the optimizer.
    • :ar - Optional map whose :regularization value sets the strength of the L1 penalty on the AR layer weights.
    • :trend - Optional map whose :regularization value sets the strength of the L1 penalty on the trend layer weights.
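A minimal sketch of how the optional `:ar` and `:trend` entries might combine with a base loss, assuming each `:regularization` value is a coefficient multiplying the corresponding L1 penalty. `RegularizedLossSketch` and `total_loss/3` are hypothetical names, and the penalties are plain numbers here rather than Nx tensors; the real `fit/5` may structure this differently.

```elixir
defmodule RegularizedLossSketch do
  # Hypothetical sketch, not Soothsayer's implementation.
  # Assumes each :regularization value is a coefficient on the matching L1 penalty.

  @components [:ar, :trend]

  # base_loss: the unregularized loss value.
  # penalties: %{ar: number, trend: number}, e.g. from an L1 penalty function.
  # config: the map passed to fit/5; :ar and :trend entries are optional.
  def total_loss(base_loss, penalties, config) do
    Enum.reduce(@components, base_loss, fn key, acc ->
      case config do
        # Component configured: add lambda * penalty.
        %{^key => %{regularization: lambda}} -> acc + lambda * Map.fetch!(penalties, key)
        # Component absent (or has no :regularization): contributes nothing.
        _ -> acc
      end
    end)
  end
end

config = %{learning_rate: 0.1, ar: %{regularization: 0.5}}
IO.inspect(RegularizedLossSketch.total_loss(1.0, %{ar: 6.0, trend: 2.0}, config))
# prints 4.0 (1.0 + 0.5 * 6.0; no :trend entry, so no trend penalty)
```

Treating a missing entry as "no penalty" mirrors the documentation's description of `:ar` and `:trend` as optional.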

Returns

The trained Axon.ModelState.

Examples

iex> config = %{learning_rate: 0.1}
iex> params = Soothsayer.Trainer.fit(network, x, y, 100, config)
%Axon.ModelState{...}