Edifice.Meta.MoEv2 (Edifice v0.2.0)


MoE v2: Expert Choice Routing + Shared Experts + Aux-Loss-Free Load Balancing.

Implements three key improvements to the Mixture of Experts architecture: expert choice routing (Zhou et al., 2022), shared expert slots (DeepSeekMoE), and aux-loss-free load balancing via trainable bias (DeepSeek-V3).

Key Innovations

1. Expert Choice Routing

Standard MoE: each token picks its top-K experts. Expert Choice: each expert picks its top-C tokens.

Standard:  token -> selects experts   (load imbalance risk)
Expert:    expert -> selects tokens   (perfect load balance)

This eliminates the need for a load-balancing auxiliary loss, since each expert processes exactly C tokens by construction.
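As a minimal sketch of expert choice routing (plain Elixir lists with made-up scores, not the library's tensor implementation), each expert independently takes the C tokens it scores highest:

```elixir
# Routing scores per expert, per token (rows: experts, cols: tokens).
scores = [
  [0.9, 0.1, 0.4, 0.6],  # expert 0's affinity for tokens 0..3
  [0.2, 0.8, 0.7, 0.3]   # expert 1's affinity for tokens 0..3
]

c = 2  # tokens_per_expert

assignments =
  Enum.map(scores, fn expert_scores ->
    expert_scores
    |> Enum.with_index()
    |> Enum.sort_by(fn {score, _idx} -> -score end)
    |> Enum.take(c)
    |> Enum.map(fn {_score, idx} -> idx end)
    |> Enum.sort()
  end)

# assignments => [[0, 3], [1, 2]]: each expert handles exactly c tokens,
# so the load is balanced by construction.
```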

2. Shared Expert Slots

Some experts are "shared" (always active for every token), while others are "routed" (selected by expert choice). This ensures a base level of computation for every token while allowing specialization:

Output = SharedExperts(x) + RoutedExperts(x)

3. Aux-Loss-Free Load Balancing (DeepSeek-V3)

Traditional MoE uses auxiliary loss to encourage uniform expert utilization. However, this creates a trade-off: higher aux loss weight improves balance but hurts model quality.

DeepSeek-V3 introduces a bias-based approach that achieves load balance without auxiliary loss:

  1. Add a trainable bias term b[i] to each expert's routing score
  2. Router computes: scores = gate_logits + bias, then each expert selects its top-C tokens
  3. After each forward pass, update the bias from expert utilization: bias[i] -= lr * (utilization[i] - target_utilization), where target_utilization = 1/num_experts

This pushes the model toward uniform utilization without interfering with the main training objective: when an expert is overused, its bias decreases, making it less likely to be selected; when underused, its bias increases. (In DeepSeek-V3, the bias affects only expert selection; the gating weights that scale expert outputs still use the original scores.)
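The update rule above can be sketched with plain lists (illustrative numbers, not the library's tensor code):

```elixir
# Minimal sketch of the DeepSeek-V3 bias update. Overused experts
# (utilization above target) get a lower bias; underused, a higher one.
utilization = [0.40, 0.10, 0.30, 0.20]
num_experts = length(utilization)
target = 1 / num_experts  # 0.25
lr = 0.001

bias = List.duplicate(0.0, num_experts)

updated =
  Enum.zip_with(bias, utilization, fn b, u -> b - lr * (u - target) end)

# Expert 0 (overused, 0.40 > 0.25): bias drops to -0.00015
# Expert 1 (underused, 0.10 < 0.25): bias rises to +0.00015
```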

Architecture

Input [batch, seq_len, input_size]
      |
      v
+----------------------------------+
|     Router (transposed scores)   |
|  score = softmax(W_r * x + b)^T  |
|  Each expert picks top-C tokens  |
+----------------------------------+
      |
+-----+-----+
|           |
v           v
Shared      Routed
Experts     Experts
(always     (expert
active)     choice)
|           |
+-----+-----+
      |
      v
Output = shared + routed

Usage

# Standard with auxiliary loss (default)
model = MoEv2.build(
  input_size: 256,
  hidden_size: 512,
  num_shared_experts: 1,
  num_routed_experts: 4,
  tokens_per_expert: 4,
  load_balance: :aux_loss
)

# DeepSeek-V3 style: aux-loss-free with bias
model = MoEv2.build(
  input_size: 256,
  load_balance: :bias
)

# After training step, update bias for :bias mode
utilization = MoEv2.compute_utilization(router_logits, tokens_per_expert)
params = MoEv2.update_load_balance_bias(params, utilization, lr: 0.001)

References

  • Zhou et al., "Mixture-of-Experts with Expert Choice Routing" (NeurIPS 2022)
  • DeepSeek-AI, "DeepSeekMoE: Towards Ultimate Expert Specialization" (2024)
  • DeepSeek-AI, "DeepSeek-V3 Technical Report" (2024) — aux-loss-free load balancing

Summary

Types

Options for build/1.

Functions

Build a MoE v2 layer with expert choice routing and shared experts.

Compute load balancing auxiliary loss.

Compute expert utilization from router logits.

Get recommended defaults.

Update the load-balance bias based on expert utilization.

Types

build_opt()

@type build_opt() ::
  {:input_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:output_size, pos_integer()}
  | {:num_shared_experts, pos_integer()}
  | {:num_routed_experts, pos_integer()}
  | {:tokens_per_expert, pos_integer()}
  | {:dropout, float()}
  | {:expert_type, :ffn | :glu}
  | {:load_balance, :aux_loss | :bias | :none}
  | {:load_balance_weight, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MoE v2 layer with expert choice routing and shared experts.

Options

  • :input_size - Input dimension (required)
  • :hidden_size - Expert hidden dimension (default: input_size * 4)
  • :output_size - Output dimension (default: input_size)
  • :num_shared_experts - Always-active experts (default: 1)
  • :num_routed_experts - Expert-choice routed experts (default: 4)
  • :tokens_per_expert - Tokens each routed expert selects (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :expert_type - :ffn or :glu (default: :ffn)
  • :load_balance - Load balancing strategy (default: :aux_loss)
    • :aux_loss — traditional auxiliary loss to encourage uniform expert utilization. Use compute_aux_loss/3 to compute the loss and add it to your training objective.
    • :bias — aux-loss-free bias term (DeepSeek-V3 approach). A trainable bias is added to router logits before softmax. At training time, use update_load_balance_bias/3 to adjust the bias based on expert utilization.
    • :none — no load balancing (standard expert choice routing provides natural balance)
  • :load_balance_weight - Weight for auxiliary loss when using :aux_loss (default: 0.01)

Returns

An Axon model for the MoE v2 layer.

compute_aux_loss(router_probs, expert_mask, opts \\ [])

@spec compute_aux_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute load balancing auxiliary loss.

This loss encourages uniform expert utilization, preventing "expert collapse" where only a few experts are used. Used when load_balance: :aux_loss.

Formula

aux_loss = alpha * num_experts * sum(f_i * P_i)

Where:

  • f_i = fraction of tokens routed to expert i
  • P_i = average router probability for expert i
  • alpha = load_balance_weight

For a perfectly balanced router, the unscaled term num_experts * sum(f_i * P_i) is approximately 1.0, so aux_loss is approximately alpha.

Parameters

  • router_probs - Router softmax probabilities [batch, seq_len, num_experts]
  • expert_mask - Binary mask of selected experts [batch, seq_len, num_experts]
  • opts - Options:
    • :load_balance_weight - Auxiliary loss weight (default: 0.01)

Returns

Scalar auxiliary loss tensor.
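The formula above can be checked on plain lists (illustrative values, not the library's tensor implementation):

```elixir
# aux_loss = alpha * num_experts * sum(f_i * P_i), at perfect balance.
alpha = 0.01                   # load_balance_weight default
f = [0.25, 0.25, 0.25, 0.25]   # fraction of tokens per expert
p = [0.25, 0.25, 0.25, 0.25]   # mean router probability per expert
n = length(f)

aux_loss =
  alpha * n * (Enum.zip_with(f, p, fn fi, pi -> fi * pi end) |> Enum.sum())

# => 0.01, i.e. the unscaled term n * sum(f_i * P_i) equals 1.0
# when utilization is perfectly uniform.
```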

compute_utilization(router_logits, tokens_per_expert)

@spec compute_utilization(Nx.Tensor.t(), pos_integer()) :: Nx.Tensor.t()

Compute expert utilization from router logits.

Returns a tensor of shape [num_experts] representing the fraction of tokens assigned to each expert. Useful for monitoring load balance and for update_load_balance_bias/3.

Parameters

  • router_logits - Router output tensor [batch, seq_len, num_experts]
  • tokens_per_expert - Number of tokens each expert selects

Returns

Tensor of shape [num_experts] with utilization ratios in [0, 1].
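Conceptually (a plain-Elixir sketch with made-up token assignments, not the library's tensor implementation), utilization is the fraction of tokens each expert selected:

```elixir
# Suppose 2 experts each selected tokens_per_expert = 2 tokens out of 4.
selections = [[0, 3], [1, 2]]  # token indices picked by each expert
total_tokens = 4

utilization =
  Enum.map(selections, fn picked -> length(picked) / total_tokens end)

# => [0.5, 0.5]  (uniform: expert choice balances by construction;
# under :bias routing, these ratios can drift and drive the bias update)
```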

update_load_balance_bias(params, utilization, opts \\ [])

@spec update_load_balance_bias(map(), Nx.Tensor.t(), keyword()) :: map()

Update the load-balance bias based on expert utilization.

Adjusts the bias to route more tokens toward underutilized experts and fewer toward overutilized ones. Call this after each training step when using load_balance: :bias.

Parameters

  • params - Model parameters (from Axon.build)
  • utilization - Expert utilization tensor from compute_utilization/2
  • opts - Options:
    • :lr - Bias learning rate (default: 0.001)
    • :bias_key - Parameter key for the bias (default: "moe_v2_load_balance_bias")

Returns

Updated model parameters.