CrucibleIR.Training.Config (CrucibleIR v0.2.1)

Configuration for model training.

Defines hyperparameters, optimizer settings, and training options for a training run.

Fields

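  • :id - Config identifier, an atom (required)
  • :model_ref - Reference to the model to train, a CrucibleIR.ModelRef (required)
  • :dataset_ref - Reference to the training dataset, a CrucibleIR.DatasetRef (required)
  • :epochs - Number of training epochs (positive integer)
  • :batch_size - Batch size (positive integer)
  • :learning_rate - Initial learning rate (float)
  • :optimizer - Optimizer type (see optimizer())
  • :loss_function - Loss function (see loss())
  • :metrics - List of metrics to track, given as atoms
  • :validation_split - Fraction of the data held out for validation (float or nil)
  • :device - Compute device (see device())
  • :seed - Random seed (integer or nil)
  • :mixed_precision - Whether to use mixed precision (boolean)
  • :gradient_clipping - Maximum gradient norm (float or nil)
  • :early_stopping - Early stopping configuration (map or nil)
  • :checkpoint_every - Checkpoint frequency (positive integer or nil)
  • :options - Additional options (map or nil)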

Examples

iex> config = %CrucibleIR.Training.Config{
...>   id: :train_gpt2,
...>   model_ref: %CrucibleIR.ModelRef{id: :gpt2},
...>   dataset_ref: %CrucibleIR.DatasetRef{name: :wikitext},
...>   epochs: 10,
...>   batch_size: 32
...> }
iex> config.epochs
10

Types

device()

@type device() :: :cpu | :cuda | :mps | :tpu | atom()

loss()

@type loss() :: :cross_entropy | :mse | :mae | :bce | atom()

optimizer()

@type optimizer() :: :adam | :sgd | :adamw | :rmsprop | atom()
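
Because device(), loss(), and optimizer() each end in `| atom()`, any atom satisfies the typespec, including a misspelled one. The sketch below flags values that are not among the atoms named in those types. `KnownValues` and `unrecognized/1` are hypothetical names used for illustration only, not part of CrucibleIR, and the helper works on any map or struct carrying these keys.

```elixir
defmodule KnownValues do
  # Hypothetical helper, not part of CrucibleIR.
  # Atoms named explicitly in the device(), loss(), and optimizer() typespecs.
  @known [
    device: [:cpu, :cuda, :mps, :tpu],
    loss_function: [:cross_entropy, :mse, :mae, :bce],
    optimizer: [:adam, :sgd, :adamw, :rmsprop]
  ]

  # Returns {field, value} pairs whose value is not one of the named atoms.
  # The `value = Map.get(...)` line doubles as a filter: a missing or nil
  # field evaluates to nil and is skipped.
  def unrecognized(config) do
    for {field, allowed} <- @known,
        value = Map.get(config, field),
        value not in allowed,
        do: {field, value}
  end
end

KnownValues.unrecognized(%{device: :cuda, loss_function: :mse, optimizer: :lion})
# => [optimizer: :lion]
```

An unrecognized atom is not necessarily an error, since the open union exists to allow custom values; a check like this is best treated as a warning.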

t()

@type t() :: %CrucibleIR.Training.Config{
  batch_size: pos_integer(),
  checkpoint_every: pos_integer() | nil,
  dataset_ref: CrucibleIR.DatasetRef.t(),
  device: device(),
  early_stopping: map() | nil,
  epochs: pos_integer(),
  gradient_clipping: float() | nil,
  id: atom(),
  learning_rate: float(),
  loss_function: loss(),
  metrics: [atom()],
  mixed_precision: boolean(),
  model_ref: CrucibleIR.ModelRef.t(),
  optimizer: optimizer(),
  options: map() | nil,
  seed: integer() | nil,
  validation_split: float() | nil
}
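
Typespecs are not enforced at runtime, so a struct can hold values that violate t(). As a minimal sketch, the hypothetical `ConfigCheck.errors/1` below (not part of CrucibleIR) applies runtime predicates derived from the scalar fields of t() and returns the names of fields that fail. It deliberately skips :model_ref, :dataset_ref, :device, :loss_function, :optimizer, :early_stopping, and :options, and it uses `Map.get/2` so that it accepts a plain map as well as the struct.

```elixir
defmodule ConfigCheck do
  # Hypothetical helper, not part of CrucibleIR.

  # Returns the fields whose values do not match the shape given in t().
  def errors(config) do
    for {field, check} <- checks(),
        not check.(Map.get(config, field)),
        do: field
  end

  # One predicate per field, mirroring the typespec of t().
  defp checks do
    [
      # is_atom(nil) is true, so nil is excluded explicitly for the required id.
      id: fn v -> is_atom(v) and not is_nil(v) end,
      epochs: &pos_int?/1,
      batch_size: &pos_int?/1,
      learning_rate: &is_float/1,
      mixed_precision: &is_boolean/1,
      metrics: fn v -> is_list(v) and Enum.all?(v, &is_atom/1) end,
      checkpoint_every: fn v -> is_nil(v) or pos_int?(v) end,
      gradient_clipping: fn v -> is_nil(v) or is_float(v) end,
      validation_split: fn v -> is_nil(v) or is_float(v) end,
      seed: fn v -> is_nil(v) or is_integer(v) end
    ]
  end

  defp pos_int?(v), do: is_integer(v) and v > 0
end

ConfigCheck.errors(%{
  id: :train_gpt2,
  epochs: 10,
  batch_size: 0,
  learning_rate: 1.0e-3,
  mixed_precision: false,
  metrics: [:accuracy]
})
# => [:batch_size]
```

Here `batch_size: 0` fails because t() requires a `pos_integer()`, while the omitted nullable fields pass as nil.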