Soft Mixture of Experts (Puigcerver et al., 2024).
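Unlike hard-routing MoE (Switch, top-K), which sends each token to a small subset of experts, Soft MoE feeds every expert a soft weighted mix of all tokens and returns, for every token, a soft weighted combination of all expert outputs. Because no token is ever assigned to or dropped from a single expert, this avoids token dropping, load-balancing problems, and routing instability while keeping the capacity benefits of MoE.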
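The dispatch and combine weights can be written out explicitly. The formulation below follows the cited paper rather than this module's source, so treat it as a sketch: the symbols are assumptions introduced here, with X the m x d matrix of tokens in one sequence, Phi the d x (n p) learnable slot parameters for n experts with p slots each, and f_k the k-th expert. This module may differ in details such as the number of slots per expert or how inputs are normalized.

```latex
\begin{aligned}
D_{ij} &= \frac{\exp\big((X\Phi)_{ij}\big)}{\sum_{i'=1}^{m} \exp\big((X\Phi)_{i'j}\big)},
  & \tilde{X} &= D^{\top} X, \\
\tilde{Y}_{j} &= f_{\lceil j/p \rceil}\big(\tilde{X}_{j}\big),
  & j &= 1, \dots, n p, \\
C_{ij} &= \frac{\exp\big((X\Phi)_{ij}\big)}{\sum_{j'=1}^{n p} \exp\big((X\Phi)_{ij'}\big)},
  & Y &= C\,\tilde{Y}.
\end{aligned}
```

D and C share the same logits and differ only in the normalization axis: each column of D sums to 1 over tokens, so a slot is a convex combination of tokens, and each row of C sums to 1 over slots, so an output token is a convex combination of slot outputs.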
Architecture
Input [batch, seq_len, embed_dim]
      |
      v
+------------------------------------+
| Input Projection                   |
+------------------------------------+
      |
      v
+------------------------------------+
| SoftMoE Block:                     |
|   1. Compute dispatch weights      |
|      D = softmax(X * Phi)          |
|   2. Compute expert inputs         |
|      X_e = D^T * X                 |
|   3. Run all experts               |
|      Y_e = Expert_e(X_e)           |
|   4. Combine outputs               |
|      Y = C * stack(Y_e)            |
|   + Residual                       |
+------------------------------------+
      | (repeat N times)
      v
Output [batch, hidden_size]
Usage
model = SoftMoE.build(
  embed_dim: 256,
  hidden_size: 256,
  num_experts: 4,
  num_layers: 4
)
References
- Puigcerver et al., "From Sparse to Soft Mixtures of Experts" (ICLR 2024)
- https://arxiv.org/abs/2308.00951
Summary
Functions
- build/1: Build a Soft MoE model.
- output_size/1: Get the output size of a Soft MoE model.
- soft_moe_block/3: Single Soft MoE block with dispatch-combine routing.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_experts, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Soft MoE model.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Hidden dimension (default: 256)
- :num_experts - Number of experts (default: 4)
- :num_layers - Number of SoftMoE blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Sequence length (default: 60)
Returns
An Axon model: [batch, seq_len, embed_dim] -> [batch, hidden_size]
@spec output_size(keyword()) :: pos_integer()
Get the output size of a Soft MoE model.
@spec soft_moe_block(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Single Soft MoE block with dispatch-combine routing.
Options
- :num_experts - Number of experts (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :name - Layer name prefix
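To make the dispatch-combine routing concrete, here is a minimal sketch in plain Nx for a single sequence. It is not this library's implementation: the module SoftMoESketch, its functions, and the choice of one slot per expert are assumptions made for illustration, and the experts are passed in as plain functions rather than Axon layers.

```elixir
# Sketch only: Soft MoE routing for one sequence, written with plain Nx.
# SoftMoESketch, softmax/2 and route/3 are hypothetical names, not this library's API.
Mix.install([{:nx, "~> 0.7"}])

defmodule SoftMoESketch do
  # Numerically stable softmax along a single axis.
  def softmax(t, axis) do
    shifted = Nx.subtract(t, Nx.reduce_max(t, axes: [axis], keep_axes: true))
    exps = Nx.exp(shifted)
    Nx.divide(exps, Nx.sum(exps, axes: [axis], keep_axes: true))
  end

  # x:       {m, d} tokens
  # phi:     {d, n} slot parameters (assumption: one slot per expert)
  # experts: list of n functions, each mapping a {d} tensor to a {d} tensor
  def route(x, phi, experts) do
    logits = Nx.dot(x, phi)                    # {m, n}
    dispatch = softmax(logits, 0)              # D: normalized over tokens
    combine = softmax(logits, 1)               # C: normalized over slots
    slots = Nx.dot(Nx.transpose(dispatch), x)  # X_e = D^T * X, shape {n, d}

    outputs =
      experts
      |> Enum.with_index()
      |> Enum.map(fn {expert, i} -> expert.(slots[i]) end)
      |> Nx.stack()                            # stack(Y_e), shape {n, d}

    # Y = C * stack(Y_e), plus the residual connection.
    Nx.add(x, Nx.dot(combine, outputs))
  end
end

# Usage: three tokens of width 2 routed through two toy experts.
x = Nx.iota({3, 2}, type: :f32)
phi = Nx.tensor([[0.1, -0.2], [0.3, 0.4]])
experts = [fn s -> Nx.multiply(s, 2.0) end, fn s -> Nx.negate(s) end]
y = SoftMoESketch.route(x, phi, experts)
IO.inspect(Nx.shape(y))
```

Every expert runs on every call, so the cost scales with the number of slots rather than with a per-token routing decision. The output keeps the input shape, which is what lets the residual be added directly.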