Object.DistributedTraining (object v0.1.2)

Distributed Low-Communication (DiLoCo) Training implementation for AAOS.

Implements the DiLoCo algorithm from "DiLoCo: Distributed Low-Communication Training of Language Models" (Douillard et al., 2023) as a specialized Object subtype within the AAOS framework.

This module enables training of large language models and other neural networks across poorly connected islands of devices, requiring minimal communication while maintaining performance comparable to fully synchronous training.

Key Features

  • Low Communication: Communicates only every H steps (hundreds/thousands)
  • Federated Architecture: Each worker operates on its own data island
  • Fault Tolerance: Byzantine resistance and graceful degradation
  • Heterogeneous Hardware: Different islands can use different device types
  • Performance: 500x less communication than synchronous training
  • Integration: Full AAOS object lifecycle and coordination

DiLoCo Algorithm

The algorithm consists of two optimization loops:

  1. Inner Optimization: Local AdamW updates for H steps
  2. Outer Optimization: Global parameter averaging with Nesterov momentum

Mathematical Foundation

For T outer steps and H inner steps per worker:

  • Total training steps: N = T × H
  • Communication frequency: Every H steps
  • Workers: k islands of devices
  • Outer gradient: Δ^(t) = (1/k) ∑ᵢ (θ^(t-1) - θᵢ^(t))
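The outer update above can be sketched over plain maps of parameter name to number (an illustrative reimplementation, not this module's internal code; `outer_gradient/2` and `nesterov_update/4` are hypothetical helpers, and real tensor math is elided):

```elixir
defmodule DiLoCoSketch do
  # Outer gradient: Δ = (1/k) Σᵢ (θ_prev - θᵢ), averaged over k workers.
  def outer_gradient(global_params, worker_params_list) do
    k = length(worker_params_list)

    Map.new(global_params, fn {name, theta_prev} ->
      delta =
        worker_params_list
        |> Enum.map(fn worker -> theta_prev - Map.fetch!(worker, name) end)
        |> Enum.sum()

      {name, delta / k}
    end)
  end

  # One Nesterov-momentum outer step:
  #   v <- μ·v + Δ;  θ <- θ - lr·(μ·v + Δ)
  def nesterov_update(params, velocity, grad, opts) do
    mu = Keyword.fetch!(opts, :momentum)
    lr = Keyword.fetch!(opts, :learning_rate)

    new_velocity =
      Map.new(velocity, fn {name, v} -> {name, mu * v + Map.fetch!(grad, name)} end)

    new_params =
      Map.new(params, fn {name, theta} ->
        {name, theta - lr * (mu * Map.fetch!(new_velocity, name) + Map.fetch!(grad, name))}
      end)

    {new_params, new_velocity}
  end
end
```

Note that each worker diverges locally for H steps, so the averaged delta behaves like a (coarse) gradient for the outer optimizer.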

Summary

Functions

child_spec(init_arg)
Returns a specification to start this module under a supervisor.

create_training_coalition(workers, coalition_config)
Initializes a training coalition of distributed workers.

get_metrics(pid)
Gets current training metrics.

inner_steps(pid, data_batch)
Performs H inner training steps on local data.

load_checkpoint(pid, checkpoint_path)
Loads a training checkpoint.

new(opts \\ [])
Creates a new DiLoCo distributed training object.

outer_step(pid)
Performs a single outer training step (one of T outer iterations).

save_checkpoint(pid, checkpoint_path)
Saves a training checkpoint.

start_link(distributed_trainer, opts \\ [])
Starts the distributed training object as a GenServer.

synchronize(pid)
Synchronizes with other workers in the coalition.

train(pid, training_data)
Executes the DiLoCo training algorithm.

Types

communication_state()

@type communication_state() :: %{
  last_sync: DateTime.t(),
  pending_gradients: map(),
  sync_barrier_count: integer(),
  communication_overhead: float(),
  bandwidth_usage: float()
}

consensus_state()

@type consensus_state() :: %{
  algorithm: :pbft | :raft | :practical_bft,
  view_number: integer(),
  leader: worker_id() | nil,
  votes: %{required(worker_id()) => vote()},
  committed_steps: integer()
}

data_shard()

@type data_shard() :: %{
  shard_id: String.t(),
  data_path: String.t(),
  total_samples: integer(),
  current_position: integer(),
  preprocessing_config: map()
}

fault_tolerance_state()

@type fault_tolerance_state() :: %{
  failed_workers: [worker_id()],
  backup_checkpoints: %{required(String.t()) => model_state()},
  consensus_state: consensus_state(),
  health_status: :healthy | :degraded | :critical
}

model_state()

@type model_state() :: %{
  parameters: %{required(String.t()) => tensor()},
  metadata: %{
    architecture: String.t(),
    layer_count: integer(),
    parameter_count: integer(),
    last_updated: DateTime.t()
  }
}

optimizer_config()

@type optimizer_config() :: %{
  inner_optimizer: :adamw | :adam | :sgd,
  outer_optimizer: :nesterov_momentum | :sgd | :adam,
  learning_rate: float(),
  momentum: float(),
  weight_decay: float(),
  beta1: float(),
  beta2: float(),
  epsilon: float()
}

optimizer_state()

@type optimizer_state() :: %{
  type: atom(),
  state: map(),
  step_count: integer(),
  accumulated_gradients: map()
}

performance_metrics()

@type performance_metrics() :: %{
  training_loss: float(),
  validation_loss: float(),
  throughput: float(),
  communication_efficiency: float(),
  convergence_rate: float(),
  wall_clock_time: integer(),
  compute_utilization: float()
}

step_counters()

@type step_counters() :: %{
  inner_step: integer(),
  outer_step: integer(),
  total_steps: integer(),
  communication_rounds: integer()
}

synchronization_barrier()

@type synchronization_barrier() :: %{
  barrier_id: String.t(),
  expected_workers: [worker_id()],
  arrived_workers: [worker_id()],
  timeout: integer(),
  start_time: DateTime.t()
}

t()

@type t() :: %Object.DistributedTraining{
  communication_state: communication_state(),
  coordination_service: pid(),
  data_shard: data_shard(),
  fault_tolerance_state: fault_tolerance_state(),
  global_model_state: model_state(),
  inner_optimizer_state: optimizer_state(),
  local_model_state: model_state(),
  object_id: String.t(),
  optimizer_config: optimizer_config(),
  outer_optimizer_state: optimizer_state(),
  performance_metrics: performance_metrics(),
  step_counters: step_counters(),
  synchronization_barrier: synchronization_barrier(),
  training_config: training_config(),
  worker_id: worker_id()
}

tensor()

@type tensor() :: %{
  shape: [integer()],
  data: binary(),
  dtype: :float32 | :float16 | :bfloat16
}
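A tensor() value can be built by packing floats into the raw `data` binary. A minimal sketch (the little-endian byte order here is an assumption; the type itself does not specify an endianness):

```elixir
# Build a 2x2 float32 tensor() value from a flat list of numbers.
values = [1.0, 2.0, 3.0, 4.0]

tensor = %{
  shape: [2, 2],
  # Pack each value as a 4-byte little-endian float32 (assumed byte order).
  data: for(v <- values, into: <<>>, do: <<v::float-32-little>>),
  dtype: :float32
}

# 4 elements x 4 bytes each = 16 bytes of payload.
byte_size(tensor.data)
```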

training_config()

@type training_config() :: %{
  inner_steps: integer(),
  outer_steps: integer(),
  batch_size: integer(),
  gradient_clipping: float(),
  communication_frequency: integer(),
  fault_tolerance_threshold: float(),
  checkpoint_frequency: integer()
}
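A hypothetical training_config() value, with illustrative numbers chosen here (not defaults documented by this module): H = 500 inner steps per round and T = 100 outer steps give N = T × H = 50,000 total steps with 100 communication rounds.

```elixir
# Hypothetical config: values are examples, not module defaults.
config = %{
  inner_steps: 500,              # H: local AdamW steps between syncs
  outer_steps: 100,              # T: global averaging rounds
  batch_size: 128,
  gradient_clipping: 1.0,
  communication_frequency: 500,  # in DiLoCo this matches inner_steps
  fault_tolerance_threshold: 0.33,
  checkpoint_frequency: 10
}

config.inner_steps * config.outer_steps  # N = T × H total training steps
```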

vote()

@type vote() :: %{
  step: integer(),
  model_hash: binary(),
  timestamp: DateTime.t(),
  signature: binary()
}

worker_id()

@type worker_id() :: String.t()

Functions

child_spec(init_arg)

Returns a specification to start this module under a supervisor.

See Supervisor.

create_training_coalition(workers, coalition_config)

Initializes a training coalition of distributed workers.

get_metrics(pid)

Gets current training metrics.

inner_steps(pid, data_batch)

Performs H inner training steps on local data.

load_checkpoint(pid, checkpoint_path)

Loads a training checkpoint.

new(opts \\ [])

Creates a new DiLoCo distributed training object.

outer_step(pid)

Performs a single outer training step (one of T outer iterations).

save_checkpoint(pid, checkpoint_path)

Saves a training checkpoint.

start_link(distributed_trainer, opts \\ [])

Starts the distributed training object as a GenServer.

synchronize(pid)

Synchronizes with other workers in the coalition.

train(pid, training_data)

Executes the DiLoCo training algorithm.
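Putting the lifecycle together, a minimal usage sketch based on the signatures above (the option names passed to `new/1` and the shape of `training_data` are assumptions, not documented here):

```elixir
# Create a trainer struct, start it as a GenServer, train, read metrics.
# Option names and return shapes below are illustrative assumptions.
trainer = Object.DistributedTraining.new(worker_id: "island_1")
{:ok, pid} = Object.DistributedTraining.start_link(trainer)

training_data = load_local_shard()  # hypothetical helper for this island's data
{:ok, _result} = Object.DistributedTraining.train(pid, training_data)

metrics = Object.DistributedTraining.get_metrics(pid)
```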