Edifice.Meta.SwitchMoE (Edifice v0.2.0)

Switch Transformer - Top-1 Expert Routing.

The Switch Transformer simplifies MoE routing by selecting only a single expert per token (top-1), reducing computation and communication costs while maintaining model capacity. Each token is routed to exactly one expert based on learned routing weights.

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+------------------------------------+
| Input Projection                   |
+------------------------------------+
      |
      v
+------------------------------------+
| Switch Block 1:                    |
|   Pre-Norm -> Router (top-1)       |
|   -> Selected Expert FFN           |
|   + Residual                       |
+------------------------------------+
      |  (repeat N times)
      v
+------------------------------------+
| Final Norm + Last Timestep         |
+------------------------------------+
      |
      v
Output [batch, hidden_size]

Router Design

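The router computes softmax probabilities over the experts and picks the highest-scoring expert for each token. Because Axon compiles a static graph, this implementation evaluates every expert for every token and applies the router's choice as a weighted combination of the expert outputs, using a peaked (near-one-hot) distribution. The result approximates hard top-1 selection, but the per-token compute savings of sparse routing do not apply here.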
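The relationship between hard top-1 selection and a peaked weighted combination can be sketched in plain Elixir, using lists instead of Nx tensors so the example has no dependencies. The module name `RoutingSketch`, its functions, and the use of a softmax temperature to sharpen the distribution are assumptions made for illustration; the source does not say how Edifice makes the distribution peaked.

```elixir
defmodule RoutingSketch do
  # Numerically stable softmax over one token's router logits.
  # A temperature below 1.0 sharpens the distribution (assumed mechanism).
  def softmax(logits, temperature \\ 1.0) do
    scaled = Enum.map(logits, &(&1 / temperature))
    peak = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - peak))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Hard top-1 routing: index of the most probable expert.
  def top1(probs) do
    probs
    |> Enum.with_index()
    |> Enum.max_by(fn {p, _index} -> p end)
    |> elem(1)
  end

  # Dense routing: every expert output is scaled by its probability and summed.
  def combine(probs, expert_outputs) do
    probs
    |> Enum.zip(expert_outputs)
    |> Enum.map(fn {p, out} -> Enum.map(out, &(&1 * p)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end
end

# One token, three experts, each producing a 2-element output vector.
logits = [0.1, 2.0, -1.0]
expert_outputs = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]

probs = RoutingSketch.softmax(logits, 0.05)
winner = RoutingSketch.top1(probs)
hard = Enum.at(expert_outputs, winner)
soft = RoutingSketch.combine(probs, expert_outputs)

IO.inspect({winner, hard, soft})
```

With a sharp enough distribution, `soft` differs from `hard` only by floating-point rounding, even though all three expert outputs took part in the sum.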

Usage

alias Edifice.Meta.SwitchMoE

model = SwitchMoE.build(
  embed_dim: 256,
  hidden_size: 256,
  num_experts: 8,
  expert_size: 512,
  num_layers: 4
)

Summary

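Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a Switch Transformer model.

output_size(opts \\ []) - Get the output size of a Switch MoE model.

switch_block(input, hidden_size, opts \\ []) - Single Switch block: pre-norm -> top-1 routed expert FFN -> residual.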

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expert_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_experts, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Switch Transformer model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Hidden dimension (default: 256)
  • :num_experts - Number of expert FFNs (default: 8)
  • :expert_size - Inner dimension of expert FFNs (default: 512)
  • :num_layers - Number of Switch blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Sequence length (default: 60)
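
The defaults listed above could be resolved against caller-supplied options roughly as follows. This is a sketch in plain Elixir: the module `BuildOptsSketch` and its `resolve/1` function are hypothetical, and only the default values are taken from the option list. How build/1 actually handles its options is not shown in this documentation.

```elixir
defmodule BuildOptsSketch do
  # Defaults copied from the documented option list.
  # The merging logic below is an assumption, not Edifice's code.
  @defaults [
    hidden_size: 256,
    num_experts: 8,
    expert_size: 512,
    num_layers: 4,
    dropout: 0.1,
    window_size: 60
  ]

  def resolve(opts) do
    # :embed_dim is documented as required, so raise KeyError when it is missing.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    # Caller-supplied values override the defaults.
    Keyword.merge(@defaults, opts)
  end
end

resolved = BuildOptsSketch.resolve(embed_dim: 128, num_experts: 4)
IO.inspect(resolved)
```

Here `resolved[:num_experts]` is the caller's 4, while `resolved[:hidden_size]` falls back to 256.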

Returns

An Axon model: [batch, seq_len, embed_dim] -> [batch, hidden_size]
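
A built model can be compiled and run to check the documented shapes. The sketch below assumes Edifice, Axon (0.7 or later, for `Axon.ModelState.empty/0`), and Nx are available in the project, that the model has a single input so a bare tensor can be passed, and that the sequence length of the input should match `:window_size`. The batch size of 2 and the all-zeros input are arbitrary.

```elixir
alias Edifice.Meta.SwitchMoE

model =
  SwitchMoE.build(
    embed_dim: 256,
    hidden_size: 256,
    num_experts: 8,
    expert_size: 512,
    num_layers: 4,
    window_size: 60
  )

# Compile the graph into an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from a shape-only template: [batch, seq_len, embed_dim].
template = Nx.template({2, 60, 256}, :f32)
params = init_fn.(template, Axon.ModelState.empty())

# Run a forward pass on a dummy input of the same shape.
input = Nx.broadcast(0.0, {2, 60, 256})
output = predict_fn.(params, input)

IO.inspect(Nx.shape(output))
```

Given the shape contract above, the printed shape should be `{2, 256}`, that is `[batch, hidden_size]`.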

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a Switch MoE model, that is, the size of the last dimension of its [batch, hidden_size] output.

switch_block(input, hidden_size, opts \\ [])

@spec switch_block(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Single Switch block: pre-norm -> top-1 routed expert FFN -> residual.
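
The block structure described above (pre-norm, router, experts, residual) might be expressed with Axon layers along these lines. This is a sketch, not Edifice's source: the module `SwitchBlockSketch`, the GELU activation, the `:temperature` option, and the concatenate-then-reshape trick for combining experts are all assumptions. It also assumes the input's last dimension already equals `hidden_size`, as it would after the input projection.

```elixir
defmodule SwitchBlockSketch do
  # Illustrative only; option names and defaults here are assumptions.
  def block(input, hidden_size, opts \\ []) do
    num_experts = Keyword.get(opts, :num_experts, 8)
    expert_size = Keyword.get(opts, :expert_size, 512)
    temperature = Keyword.get(opts, :temperature, 0.1)

    # Pre-norm.
    normed = Axon.layer_norm(input)

    # Router: per-token probabilities over experts, sharpened by a low temperature.
    probs =
      normed
      |> Axon.dense(num_experts)
      |> Axon.nx(&Nx.divide(&1, temperature))
      |> Axon.activation(:softmax)

    # Every expert FFN is evaluated, since the graph is static.
    experts =
      for _ <- 1..num_experts do
        normed
        |> Axon.dense(expert_size, activation: :gelu)
        |> Axon.dense(hidden_size)
      end

    # Concatenate expert outputs on the feature axis, then mix them by probability.
    stacked = Axon.concatenate(experts, axis: -1)
    mixed = Axon.layer(&weighted_sum/3, [probs, stacked])

    # Residual connection.
    Axon.add(input, mixed)
  end

  defp weighted_sum(probs, stacked, _opts) do
    {batch, seq, num_experts} = Nx.shape(probs)
    # {batch, seq, num_experts * hidden} -> {batch, seq, num_experts, hidden}
    per_expert = Nx.reshape(stacked, {batch, seq, num_experts, :auto})
    # Broadcast each expert's probability across its hidden dimension.
    weights = Nx.new_axis(probs, -1)
    Nx.sum(Nx.multiply(per_expert, weights), axes: [2])
  end
end

input = Axon.input("tokens", shape: {nil, 60, 256})
block = SwitchBlockSketch.block(input, 256, num_experts: 4, expert_size: 512)
```

The residual add means the block preserves the input shape, so several blocks can be chained by feeding one block's output into the next.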