Switch Transformer: Top-1 Expert Routing
The Switch Transformer simplifies mixture-of-experts (MoE) routing by sending each token to a single expert (top-1) rather than to several. In a sparse implementation this reduces computation and communication costs while keeping the model capacity that comes from having many experts. Each token is routed to exactly one expert, chosen by a learned router.
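A minimal sketch of top-1 routing for a single token, written in plain Elixir with lists standing in for Nx tensors. The module and function names are hypothetical and are not this library's API. Following the Switch Transformer paper, only the selected expert runs, and its output is scaled by that expert's router probability (the gate value).

```elixir
defmodule Top1RoutingSketch do
  # Hypothetical illustration of Switch-style top-1 routing for one token.
  # Vectors are plain lists of floats; no Axon or Nx is involved.

  # Numerically stable softmax over a list of router logits.
  def softmax(logits) do
    max_logit = Enum.max(logits)
    exps = Enum.map(logits, &:math.exp(&1 - max_logit))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Returns {expert_index, gate_probability} for the highest-scoring expert.
  def route(logits) do
    {gate, index} =
      logits
      |> softmax()
      |> Enum.with_index()
      |> Enum.max_by(fn {prob, _index} -> prob end)

    {index, gate}
  end

  # Runs only the selected expert and scales its output by the gate value.
  def forward(token, logits, experts) do
    {index, gate} = route(logits)
    expert = Enum.at(experts, index)
    Enum.map(expert.(token), &(&1 * gate))
  end
end

# Example: expert 1 has the larger logit, so only its output is returned,
# scaled by a gate of roughly 0.95.
experts = [fn v -> Enum.map(v, &(&1 * 2.0)) end, fn v -> Enum.map(v, &(&1 + 1.0)) end]
Top1RoutingSketch.forward([1.0, 2.0], [0.0, 3.0], experts)
```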
Architecture
Input [batch, seq_len, embed_dim]
|
v
+------------------------------------+
| Input Projection |
+------------------------------------+
|
v
+------------------------------------+
| Switch Block 1: |
| Pre-Norm -> Router (top-1) |
| -> Selected Expert FFN |
| + Residual |
+------------------------------------+
| (repeat N times)
v
+------------------------------------+
| Final Norm + Last Timestep |
+------------------------------------+
|
v
Output [batch, hidden_size]
Router Design
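The router computes softmax probabilities over the experts and selects the highest-scoring expert for each token. Because Axon builds static computation graphs, every expert is evaluated on every token, and top-1 selection is realized as a weighted combination of the expert outputs using a sharply peaked (near one-hot) routing distribution. As a result, per-token compute in this implementation grows with :num_experts; the top-1 behavior lives in the routing weights rather than in sparse dispatch.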
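The static-graph approach can be sketched as follows, under one explicit assumption: the text says the routing distribution is "peaked" but not how it is sharpened, so a softmax with a small temperature is used here as a hypothetical stand-in. All experts are evaluated and their outputs are mixed by the routing weights. Plain lists replace Axon/Nx tensors, and every name is hypothetical.

```elixir
defmodule DenseSwitchSketch do
  # Hypothetical static-graph variant: every expert runs on every token and
  # the outputs are mixed by a peaked routing distribution. The temperature
  # parameter is an assumption; the docs only say the distribution is
  # "peaked (near-one-hot)", not how it is sharpened.

  # Softmax of logits / temperature; smaller temperatures give peakier weights.
  def softmax(logits, temperature) do
    scaled = Enum.map(logits, &(&1 / temperature))
    max_scaled = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - max_scaled))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  def forward(token, logits, experts, temperature \\ 0.1) do
    weights = softmax(logits, temperature)

    # No sparse dispatch: all experts are evaluated.
    outputs = Enum.map(experts, fn expert -> expert.(token) end)

    # Weighted sum of the expert outputs, element by element.
    outputs
    |> Enum.zip(weights)
    |> Enum.map(fn {out, w} -> Enum.map(out, &(&1 * w)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end
end
```

Lowering the temperature moves the result toward the hard top-1 output, but the cost of evaluating every expert stays the same.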
Usage
model = SwitchMoE.build(
embed_dim: 256,
hidden_size: 256,
num_experts: 8,
expert_size: 512,
num_layers: 4
)
References
- Fedus et al., "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity" (JMLR 2022)
- https://arxiv.org/abs/2101.03961
Summary
Functions
build/1 - Build a Switch Transformer model.
output_size/1 - Get the output size of a Switch MoE model.
switch_block/3 - Single Switch block: pre-norm -> top-1 routed expert FFN -> residual.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:expert_size, pos_integer()} | {:hidden_size, pos_integer()} | {:num_experts, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
build/1
Build a Switch Transformer model.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Hidden dimension (default: 256)
- :num_experts - Number of expert FFNs (default: 8)
- :expert_size - Inner dimension of expert FFNs (default: 512)
- :num_layers - Number of Switch blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Sequence length (default: 60)
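To make the option defaults above concrete, here is a small hypothetical helper that fills them in and enforces the one required option. It is not part of the library; only the option keys and default values come from the list above, and the module name is made up.

```elixir
defmodule SwitchOptsSketch do
  # Hypothetical helper that mirrors the documented build/1 defaults.
  # It is not the library's code; only the option keys and values come from the docs.
  @defaults [
    hidden_size: 256,
    num_experts: 8,
    expert_size: 512,
    num_layers: 4,
    dropout: 0.1,
    window_size: 60
  ]

  # Fills in defaults (caller-supplied values win) and requires :embed_dim.
  def resolve(opts) do
    if not Keyword.has_key?(opts, :embed_dim) do
      raise ArgumentError, "missing required option :embed_dim"
    end

    Keyword.merge(@defaults, opts)
  end
end
```

For example, SwitchOptsSketch.resolve(embed_dim: 128, num_experts: 4) keeps :num_experts at 4 and fills every other option with its documented default.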
Returns
An Axon model: [batch, seq_len, embed_dim] -> [batch, hidden_size]
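The drop from [batch, seq_len, ...] to [batch, hidden_size] comes from the "Last Timestep" step in the architecture diagram. The sketch below shows only that slicing, using nested lists as a stand-in for a tensor; the final norm that precedes it is omitted, and the module name is hypothetical.

```elixir
defmodule LastTimestepSketch do
  # Hypothetical illustration of the "Last Timestep" step: nested lists stand
  # in for a [batch, seq_len, hidden_size] tensor, and only the final position
  # of each sequence is kept, giving [batch, hidden_size].
  def last_timestep(batch), do: Enum.map(batch, &List.last/1)
end
```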
@spec output_size(keyword()) :: pos_integer()
Get the output size of a Switch MoE model, i.e. the last dimension of its [batch, hidden_size] output (the :hidden_size option).
@spec switch_block(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Single Switch block: pre-norm -> top-1 routed expert FFN -> residual.
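The block wiring described above (pre-norm, top-1 routed expert, residual) can be sketched for a single token as follows. This is a hypothetical illustration, not the library's switch_block/3: layer norm has no learned scale or shift, the router is any function returning one logit per expert, and each expert maps a vector to a vector of the same length.

```elixir
defmodule SwitchBlockSketch do
  # Hypothetical single-token Switch block: pre-norm -> top-1 expert -> residual.
  # Assumptions: layer norm has no learned scale/shift, the router is any
  # function from a vector to one logit per expert, and each expert maps a
  # vector to a vector of the same length.

  # Layer norm without learned parameters.
  def layer_norm(x, eps \\ 1.0e-5) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.sum(Enum.map(x, &((&1 - mean) * (&1 - mean)))) / n
    Enum.map(x, &((&1 - mean) / :math.sqrt(var + eps)))
  end

  # Numerically stable softmax.
  def softmax(logits) do
    max_logit = Enum.max(logits)
    exps = Enum.map(logits, &:math.exp(&1 - max_logit))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  def forward(x, router, experts) do
    h = layer_norm(x)

    # Top-1 selection on the normalized input.
    {gate, index} =
      h
      |> router.()
      |> softmax()
      |> Enum.with_index()
      |> Enum.max_by(fn {prob, _index} -> prob end)

    expert_out = Enum.at(experts, index).(h)

    # Residual: add the gated expert output back to the un-normalized input.
    Enum.zip_with(x, expert_out, fn xi, yi -> xi + gate * yi end)
  end
end
```

Applying the residual to the un-normalized input follows the pre-norm layout in the diagram, where normalization feeds the router and expert but the skip path bypasses it.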