MoE v2: Expert Choice Routing + Shared Experts + Aux-Loss-Free Load Balancing.
Implements three key improvements to the Mixture of Experts architecture: expert choice routing (Zhou et al., 2022), shared expert slots (DeepSeekMoE), and aux-loss-free load balancing via trainable bias (DeepSeek-V3).
Key Innovations
1. Expert Choice Routing
Standard MoE: each token picks its top-K experts. Expert Choice: each expert picks its top-C tokens.
Standard:      token  -> selects experts  (load imbalance risk)
Expert choice: expert -> selects tokens   (perfect load balance)

This eliminates the need for a load-balancing auxiliary loss, since each expert processes exactly C tokens by construction.
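The transposed selection can be sketched in a few lines. This is an illustrative NumPy sketch of the routing idea, not this module's implementation; all names are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, num_experts, c = 8, 4, 2

# Router scores per token, normalized over experts.
logits = rng.normal(size=(num_tokens, num_experts))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

# Transpose: rows are experts, columns are tokens.
scores_t = probs.T                                # [num_experts, num_tokens]

# Each expert selects its C highest-scoring tokens.
selected = np.argsort(-scores_t, axis=-1)[:, :c]  # [num_experts, C]

# Perfect load balance by construction: every expert gets exactly C tokens,
# regardless of how skewed the scores are.
```

Note that a token may be picked by several experts, or by none; expert choice trades guaranteed per-expert load for variable per-token coverage.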
2. Shared Expert Slots
Some experts are "shared" (always active for every token), while others are "routed" (selected by expert choice). This ensures a base level of computation for every token while allowing specialization:
Output = SharedExperts(x) + RoutedExperts(x)

3. Aux-Loss-Free Load Balancing (DeepSeek-V3)
Traditional MoE uses auxiliary loss to encourage uniform expert utilization. However, this creates a trade-off: higher aux loss weight improves balance but hurts model quality.
DeepSeek-V3 introduces a bias-based approach that achieves load balance without auxiliary loss:
- Add a trainable bias term b[i] to each expert's routing score
- Router computes scores = gate_logits + bias, then selects top-K
- After each forward pass, update the bias based on expert utilization:
  bias[i] -= lr * (utilization[i] - target_utilization)
  where target_utilization = 1/num_experts
This pushes the model toward uniform utilization without interfering with the main training objective. When an expert is overused, its bias decreases, making it less likely to be selected. When underused, bias increases.
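One step of the update rule above, sketched in NumPy (variable names are illustrative, not this module's parameter keys):

```python
import numpy as np

num_experts = 4
lr = 0.001
target = 1.0 / num_experts               # uniform target utilization
bias = np.zeros(num_experts)

# Suppose this step expert 0 was overused and expert 3 underused.
utilization = np.array([0.40, 0.25, 0.25, 0.10])

# Overused experts have their bias lowered (less likely to be picked next
# step); underused experts have it raised.
bias -= lr * (utilization - target)
```

Because the correction is proportional to the utilization gap, the bias settles once utilization reaches the uniform target, without adding any term to the training loss.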
Architecture
Input [batch, seq_len, input_size]
|
v
+----------------------------------+
| Router (transposed scores) |
| score = softmax(W_r * x + b)^T |
| Each expert picks top-C tokens |
+----------------------------------+
|
+-----+-----+
| |
v v
Shared Routed
Experts Experts
(always (expert
active) choice)
| |
+-----+-----+
|
v
Output = shared + routed

Usage
# Standard with auxiliary loss (default)
model = MoEv2.build(
input_size: 256,
hidden_size: 512,
num_shared_experts: 1,
num_routed_experts: 4,
tokens_per_expert: 4,
load_balance: :aux_loss
)
# DeepSeek-V3 style: aux-loss-free with bias
model = MoEv2.build(
input_size: 256,
load_balance: :bias
)
# After training step, update bias for :bias mode
utilization = MoEv2.compute_utilization(router_logits, tokens_per_expert)
params = MoEv2.update_load_balance_bias(params, utilization, lr: 0.001)

References
- Zhou et al., "Mixture-of-Experts with Expert Choice Routing" (NeurIPS 2022)
- DeepSeek-AI, "DeepSeekMoE: Towards Ultimate Expert Specialization" (2024)
- DeepSeek-AI, "DeepSeek-V3 Technical Report" (2024) — aux-loss-free load balancing
Summary
Functions
Build a MoE v2 layer with expert choice routing and shared experts.
Compute load balancing auxiliary loss.
Compute expert utilization from router logits.
Get recommended defaults.
Update the load-balance bias based on expert utilization.
Types
@type build_opt() ::
        {:input_size, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:output_size, pos_integer()}
        | {:num_shared_experts, pos_integer()}
        | {:num_routed_experts, pos_integer()}
        | {:tokens_per_expert, pos_integer()}
        | {:dropout, float()}
        | {:expert_type, :ffn | :glu}
        | {:load_balance, :aux_loss | :bias | :none}
        | {:load_balance_weight, float()}
Options for build/1.
Functions
Build a MoE v2 layer with expert choice routing and shared experts.
Options
:input_size - Input dimension (required)
:hidden_size - Expert hidden dimension (default: input_size * 4)
:output_size - Output dimension (default: input_size)
:num_shared_experts - Always-active experts (default: 1)
:num_routed_experts - Expert-choice routed experts (default: 4)
:tokens_per_expert - Tokens each routed expert selects (default: 4)
:dropout - Dropout rate (default: 0.1)
:expert_type - :ffn or :glu (default: :ffn)
:load_balance - Load balancing strategy (default: :aux_loss)
  :aux_loss — traditional auxiliary loss to encourage uniform expert utilization. Use compute_aux_loss/3 to compute the loss and add it to your training objective.
  :bias — aux-loss-free bias term (DeepSeek-V3 approach). A trainable bias is added to router logits before softmax. At training time, use update_load_balance_bias/3 to adjust the bias based on expert utilization.
  :none — no load balancing (standard expert choice routing provides natural balance)
:load_balance_weight- Weight for auxiliary loss when using:aux_loss(default: 0.01)
Returns
An Axon model for the MoE v2 layer.
@spec compute_aux_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute load balancing auxiliary loss.
This loss encourages uniform expert utilization, preventing "expert collapse"
where only a few experts are used. Used when load_balance: :aux_loss.
Formula
aux_loss = alpha * num_experts * sum(f_i * P_i)

Where:
- f_i = fraction of tokens routed to expert i
- P_i = average router probability for expert i
- alpha = load_balance_weight
For a perfectly balanced router, num_experts * sum(f_i * P_i) is approximately 1.0, so aux_loss is approximately alpha.
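A quick numeric check of the formula, in NumPy, for the perfectly balanced case (purely illustrative):

```python
import numpy as np

alpha, num_experts = 0.01, 4

# Perfectly balanced: each expert receives 1/4 of the tokens and 1/4 of
# the average router probability mass.
f = np.full(num_experts, 0.25)  # f_i: fraction of tokens routed to expert i
p = np.full(num_experts, 0.25)  # P_i: average router probability for expert i

aux_loss = alpha * num_experts * np.sum(f * p)
# num_experts * sum(f_i * P_i) == 4 * (4 * 0.25 * 0.25) == 1.0,
# so aux_loss equals alpha here.
```

Any imbalance raises sum(f_i * P_i) above 1/num_experts (by the rearrangement inequality, since f and P skew together), so the loss penalizes collapse onto few experts.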
Parameters
router_probs - Router softmax probabilities, [batch, seq_len, num_experts]
expert_mask - Binary mask of selected experts, [batch, seq_len, num_experts]
opts - Options:
  :load_balance_weight - Auxiliary loss weight (default: 0.01)
Returns
Scalar auxiliary loss tensor.
@spec compute_utilization(Nx.Tensor.t(), pos_integer()) :: Nx.Tensor.t()
Compute expert utilization from router logits.
Returns a tensor of shape [num_experts] representing the fraction of
tokens assigned to each expert. Useful for monitoring load balance and
for update_load_balance_bias/3.
Parameters
router_logits - Router output tensor, [batch, seq_len, num_experts]
tokens_per_expert - Number of tokens each expert selects
Returns
Tensor of shape [num_experts] with utilization ratios in [0, 1].
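As a rough sketch of the semantics, utilization can be read as "fraction of routing assignments landing on each expert". The NumPy sketch below assumes top-K token-choice selection over the biased logits (as described for :bias mode above); it is an assumption about the measurement, not this module's source.

```python
import numpy as np

num_tokens, num_experts, top_k = 8, 4, 2
rng = np.random.default_rng(1)
logits = rng.normal(size=(num_tokens, num_experts))

# Each token picks its top-K experts by logit.
picks = np.argsort(-logits, axis=-1)[:, :top_k]        # [num_tokens, K]

# Utilization: fraction of all assignments that land on each expert.
counts = np.bincount(picks.ravel(), minlength=num_experts)
utilization = counts / counts.sum()                    # sums to 1.0
```

A uniform tensor (every entry near 1/num_experts) indicates good balance; a spiky one is the signal that drives the bias update in :bias mode.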
@spec recommended_defaults() :: keyword()
Get recommended defaults.
@spec update_load_balance_bias(map(), Nx.Tensor.t(), keyword()) :: map()
Update the load-balance bias based on expert utilization.
Adjusts the bias to route more tokens toward underutilized experts
and fewer toward overutilized ones. Call this after each training step
when using load_balance: :bias.
Parameters
params - Model parameters (from Axon.build)
utilization - Expert utilization tensor from compute_utilization/2
opts - Options:
  :lr - Bias learning rate (default: 0.001)
  :bias_key - Parameter key for the bias (default: "moe_v2_load_balance_bias")
Returns
Updated model parameters.