Mixture of Experts (MoE) for adaptive expert selection.
Overview
MoE routes each input to a subset of specialized "expert" networks based on a learned routing function. This allows the model to have much larger capacity while maintaining fast inference (only K experts are active per input).
```
             Input x
                |
                v
       +-----------------+
       |      Router     | -> Selects top-K experts
       |  (softmax gate) |
       +--------+--------+
                |
  +------+------+------+------+
  v      v      v      v      v
+---+  +---+  +---+  +---+  +---+
|E1 |  |E2 |  |E3 |  |E4 |  |E5 |   (Experts)
+-+-+  +-+-+  +---+  +---+  +---+
  |      |         (inactive)
  v      v
 weighted sum
      |
      v
   Output
```

Expert Specialization
Different experts can specialize on different input patterns:
- Expert 1: Common patterns (frequent states)
- Expert 2: Transition states (changes between modes)
- Expert 3: Edge cases (rare but important situations)
- Expert 4: Fine-grained distinctions (subtle differences)
Routing Strategies
| Strategy | Description | Load Balance |
|---|---|---|
| `:top_k` | Select K highest-scoring experts | Requires aux loss |
| `:switch` | Route to single best expert | Best balance |
| `:soft` | Weighted sum of all experts | Most expensive |
| `:hash` | Deterministic based on input hash | Perfect balance |
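For intuition, below is a minimal sketch of `:top_k`-style gating written directly in Nx. The module and function names are illustrative only, not part of this module's API:

```elixir
defmodule RouterSketch do
  # Illustrative top-k gating; NOT this module's internal implementation.
  # logits: {batch, num_experts} router scores per token.
  def top_k_gate(logits, k) do
    # Softmax over the expert dimension.
    max = Nx.reduce_max(logits, axes: [-1], keep_axes: true)
    exp = Nx.exp(Nx.subtract(logits, max))
    probs = Nx.divide(exp, Nx.sum(exp, axes: [-1], keep_axes: true))

    # Keep the K highest-probability experts for each token.
    {top_probs, top_idx} = Nx.top_k(probs, k: k)

    # Renormalize the surviving gate weights so they sum to 1.
    gates = Nx.divide(top_probs, Nx.sum(top_probs, axes: [-1], keep_axes: true))
    {gates, top_idx}
  end
end
```

In these terms, `:switch` corresponds to `k = 1`, while `:soft` keeps every expert's weight instead of selecting a subset.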
Usage
```elixir
# Create MoE layer with 8 experts, top-2 routing
moe = MoE.build(
  input_size: 256,
  hidden_size: 512,
  num_experts: 8,
  top_k: 2,
  routing: :top_k
)

# With load balancing loss
{output, aux_loss} = MoE.forward_with_aux(moe, input, params)
```
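During training, the auxiliary loss is typically added to the task loss. A minimal sketch, assuming `targets` holds the ground-truth tensor and that `aux_loss` already includes the `:load_balance_weight` factor:

```elixir
# Sketch: fold the load-balancing loss into the training objective.
{output, aux_loss} = MoE.forward_with_aux(moe, input, params)

task_loss = Axon.Losses.mean_squared_error(targets, output, reduction: :mean)
total_loss = Nx.add(task_loss, aux_loss)
```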
Summary
Functions
Build a Mixture of Experts layer.
Build a complete MoE block with pre-norm and residual.
Build an MoE-enhanced backbone by replacing FFN layers with MoE.
Compute load balancing auxiliary loss.
Calculate theoretical speedup from MoE.
Get recommended MoE configuration.
Types
```elixir
@type build_opt() ::
        {:dropout, float()}
        | {:expert_type, :ffn | :glu | :mamba}
        | {:hidden_size, pos_integer()}
        | {:input_size, pos_integer()}
        | {:num_experts, pos_integer()}
        | {:output_size, pos_integer()}
        | {:routing, routing_strategy()}
        | {:top_k, pos_integer()}
```
Options for build/1.
@type routing_strategy() :: :top_k | :switch | :soft | :hash
Functions
Build a Mixture of Experts layer.
Options
Architecture:
- `:input_size` - Input dimension (required)
- `:hidden_size` - Expert hidden dimension (default: `input_size * 4`)
- `:output_size` - Output dimension (default: `input_size`)
- `:num_experts` - Number of expert networks (default: 8)
- `:top_k` - Number of experts per input (default: 2)
- `:routing` - Routing strategy (default: `:top_k`)
Regularization:
- `:dropout` - Dropout rate (default: 0.1)
- `:capacity_factor` - Max tokens per expert multiplier (default: 1.25)
- `:load_balance_weight` - Auxiliary loss weight (default: 0.01)
Expert architecture:
- `:expert_type` - `:ffn`, `:glu`, or `:mamba` (default: `:ffn`)
Returns
An Axon model for the MoE layer.
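For example, a GLU-expert variant with an explicit output size might be built like this (the specific values are illustrative):

```elixir
moe = MoE.build(
  input_size: 512,
  hidden_size: 1024,
  output_size: 512,
  num_experts: 16,
  top_k: 2,
  expert_type: :glu
)
```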
Build a complete MoE block with pre-norm and residual.
This wraps the MoE layer with the standard transformer block pattern.
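Conceptually the block computes `x + MoE(LayerNorm(x))`. A rough Axon wiring sketch, where `moe_sublayer` stands in for the layer built by this module (an assumption, since the exact composition API is not shown here):

```elixir
# Pre-norm residual pattern around an MoE sublayer (illustrative wiring).
input = Axon.input("x", shape: {nil, 256})
normed = Axon.layer_norm(input)
moe_out = moe_sublayer.(normed)   # hypothetical: applies the MoE sublayer
block = Axon.add(input, moe_out)  # residual connection
```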
Build an MoE-enhanced backbone by replacing FFN layers with MoE.
Takes an existing backbone configuration and converts FFN sublayers to MoE.
Options
- `:backbone` - Base backbone (`:mamba`, `:attention`, etc.)
- `:moe_every` - Apply MoE every N layers (default: 2)
- `:num_experts` - Experts per MoE layer (default: 8)
- `:top_k` - Active experts per input (default: 2)
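A hypothetical invocation with these options (the function name `build_backbone/1` is a guess for illustration; only the options above are documented here):

```elixir
# Hypothetical call; the actual function name is not shown in this doc.
backbone = MoE.build_backbone(
  backbone: :mamba,
  moe_every: 2,
  num_experts: 8,
  top_k: 2
)
```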
@spec compute_aux_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute load balancing auxiliary loss.
This loss encourages uniform expert utilization, preventing "expert collapse" where only a few experts are used.
Formula
aux_loss = alpha * num_experts * sum(f_i * P_i)

Where:
- f_i = fraction of tokens routed to expert i
- P_i = average router probability for expert i
- alpha = load_balance_weight
For a perfectly balanced router, num_experts * sum(f_i * P_i) is approximately 1.0, so aux_loss approaches its minimum of alpha.
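A rough Nx sketch of this formula, assuming `router_probs` is a `{tokens, num_experts}` probability matrix and `expert_mask` is a one-hot `{tokens, num_experts}` assignment matrix (names and shapes are assumptions, not `compute_aux_loss/3`'s actual contract):

```elixir
defmodule AuxLossSketch do
  # Illustrative load-balancing loss; not this module's compute_aux_loss/3.
  def aux_loss(router_probs, expert_mask, alpha \\ 0.01) do
    num_experts = Nx.axis_size(router_probs, 1)

    # f_i: fraction of tokens routed to expert i.
    f = Nx.mean(expert_mask, axes: [0])

    # P_i: average router probability for expert i.
    p = Nx.mean(router_probs, axes: [0])

    # alpha * num_experts * sum(f_i * P_i); approaches alpha when balanced.
    Nx.multiply(alpha * num_experts, Nx.sum(Nx.multiply(f, p)))
  end
end
```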
@spec estimate_speedup(pos_integer(), pos_integer(), float()) :: float()
Calculate theoretical speedup from MoE.
Arguments
- `num_experts` - Total number of experts
- `top_k` - Active experts per input
- `expert_fraction` - Fraction of model that is expert layers
Returns
Approximate FLOPs reduction ratio.
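One simple way such an estimate could be derived (an assumed cost model, not confirmed by this doc): expert FLOPs scale with `top_k / num_experts`, while the rest of the network is unchanged.

```elixir
# Assumed cost model (illustrative only):
#   active = expert_fraction * top_k / num_experts + (1 - expert_fraction)
#   speedup ~= 1 / active
#
# Example with 8 experts, top-2, experts = 50% of FLOPs:
#   active = 0.5 * 2 / 8 + 0.5 = 0.625  =>  speedup ~= 1.6x
speedup = MoE.estimate_speedup(8, 2, 0.5)
```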
@spec recommended_defaults() :: keyword()
Get recommended MoE configuration.
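A plausible usage pattern merges the recommended defaults with the required `:input_size` before building (this composition is an assumption; only the spec above is documented):

```elixir
opts = Keyword.merge(MoE.recommended_defaults(), input_size: 256)
moe = MoE.build(opts)
```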