MLP-Mixer - All-MLP architecture for vision.
Replaces attention and convolutions entirely with MLPs. Uses two types of MLP layers applied alternately: token-mixing MLPs that operate across spatial locations (patches), and channel-mixing MLPs that operate within each location independently.
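To make the patch-embedding step concrete, here is a hedged NumPy sketch of the shape transformations (the `patchify` helper and `W_embed` are illustrative names, not part of this library; sizes match the CIFAR-10 example below):

```python
import numpy as np

def patchify(images, patch_size):
    # images: [batch, channels, height, width] -> [batch, num_patches, patch_dim]
    b, c, hgt, wid = images.shape
    p = patch_size
    x = images.reshape(b, c, hgt // p, p, wid // p, p)
    x = x.transpose(0, 2, 4, 1, 3, 5)   # [b, h/p, w/p, c, p, p]
    return x.reshape(b, (hgt // p) * (wid // p), c * p * p)

rng = np.random.default_rng(0)
images = rng.normal(size=(2, 3, 32, 32))
patches = patchify(images, 4)           # 8 * 8 = 64 patches of dim 3 * 4 * 4 = 48
W_embed = rng.normal(size=(48, 256))    # linear projection to hidden_size
tokens = patches @ W_embed
print(tokens.shape)                     # (2, 64, 256)
```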
Architecture
Image [batch, channels, height, width]
            |
+-----------v---------------+
| Patch Embedding           |  Split into P x P patches, linear project
+---------------------------+
            |
            v
[batch, num_patches, hidden_size]
            |
+-----------v---------------+
| Mixer Layer x N           |
|                           |
| Token Mixing:             |
|   LN -> Transpose         |
|   -> Dense(token_mlp_dim) |
|   -> GELU                 |
|   -> Dense(num_patches)   |
|   -> Transpose            |
|   + Residual              |
|                           |
| Channel Mixing:           |
|   LN -> Dense(ch_mlp_dim) |
|   -> GELU                 |
|   -> Dense(hidden_size)   |
|   + Residual              |
+---------------------------+
            |
+-----------v---------------+
| LayerNorm                 |
+---------------------------+
            |
+-----------v---------------+
| Global Average Pool       |  Mean over patches
+---------------------------+
            |
            v
[batch, hidden_size]
            |
+-----------v---------------+
| Optional Classifier       |
+---------------------------+

Key Insight
Token-mixing MLPs allow communication between different spatial locations, while channel-mixing MLPs process features within each location. This separation is analogous to depthwise separable convolutions but uses fully-connected layers, achieving competitive results without attention.
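The two mixing steps can be sketched in a few lines of NumPy, assuming the layer structure shown in the diagram above (parameter names and the toy shapes are illustrative, not the library's):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize over the last (channel) axis.
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w1, b1, w2, b2):
    return gelu(x @ w1 + b1) @ w2 + b2

def mixer_layer(x, params):
    # x: [batch, num_patches, hidden_size]
    # Token mixing: transpose so the Dense layers act across patches.
    y = layer_norm(x).transpose(0, 2, 1)   # [b, hidden, patches]
    y = mlp(y, *params["token"])           # Dense(token_mlp_dim) -> GELU -> Dense(num_patches)
    x = x + y.transpose(0, 2, 1)           # residual
    # Channel mixing: the Dense layers act within each patch.
    x = x + mlp(layer_norm(x), *params["channel"])
    return x

# Toy shapes: batch 2, 4 patches, hidden_size 8
rng = np.random.default_rng(0)
b, p, h, dt, dc = 2, 4, 8, 16, 32
params = {
    "token":   (rng.normal(size=(p, dt)), np.zeros(dt), rng.normal(size=(dt, p)), np.zeros(p)),
    "channel": (rng.normal(size=(h, dc)), np.zeros(dc), rng.normal(size=(dc, h)), np.zeros(h)),
}
x = rng.normal(size=(b, p, h))
out = mixer_layer(x, params)
print(out.shape)   # (2, 4, 8) -- the layer preserves the token grid shape
```

Note how the only difference between the two MLPs is which axis the transpose exposes to the Dense layers.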
Usage
# MLP-Mixer-B/16
model =
  MLPMixer.build(
    image_size: 224,
    patch_size: 16,
    hidden_size: 768,
    num_layers: 12,
    token_mlp_dim: 384,
    channel_mlp_dim: 3072,
    num_classes: 1000
  )
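The B/16 configuration above implies a 14 x 14 patch grid. A back-of-the-envelope check of the patch count and per-layer MLP parameter counts (my arithmetic, not library output):

```python
image_size, patch_size = 224, 16
hidden, token_mlp, channel_mlp = 768, 384, 3072

num_patches = (image_size // patch_size) ** 2   # 14 * 14 = 196

# Token-mixing MLP: Dense(196 -> 384) and Dense(384 -> 196), with biases
token_params = num_patches * token_mlp + token_mlp + token_mlp * num_patches + num_patches

# Channel-mixing MLP: Dense(768 -> 3072) and Dense(3072 -> 768), with biases
channel_params = hidden * channel_mlp + channel_mlp + channel_mlp * hidden + hidden

print(num_patches)      # 196
print(token_params)     # 151108
print(channel_params)   # 4722432
```

The channel-mixing MLP dominates the parameter count, which is why `channel_mlp_dim` is typically several times larger than `token_mlp_dim`.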
# Small Mixer for CIFAR-10
model =
  MLPMixer.build(
    image_size: 32,
    patch_size: 4,
    hidden_size: 256,
    num_layers: 8,
    token_mlp_dim: 128,
    channel_mlp_dim: 1024,
    num_classes: 10
  )

References
- "MLP-Mixer: An all-MLP Architecture for Vision" (Tolstikhin et al., NeurIPS 2021)
Summary
Types
@type build_opt() ::
        {:channel_mlp_dim, pos_integer()}
        | {:dropout, float()}
        | {:hidden_size, pos_integer()}
        | {:image_size, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:num_classes, pos_integer() | nil}
        | {:num_layers, pos_integer()}
        | {:patch_size, pos_integer()}
        | {:token_mlp_dim, pos_integer()}
Options for build/1.
Functions
Build an MLP-Mixer model.
Options
:image_size - Input image size, square (default: 224)
:patch_size - Patch size, square (default: 16)
:in_channels - Number of input channels (default: 3)
:hidden_size - Hidden dimension per patch (default: 512)
:num_layers - Number of mixer layers (default: 8)
:token_mlp_dim - Token-mixing MLP hidden dimension (default: 256)
:channel_mlp_dim - Channel-mixing MLP hidden dimension (default: 2048)
:dropout - Dropout rate (default: 0.0)
:num_classes - Number of classes for classification head (optional)
Returns
An Axon model. Without :num_classes, outputs [batch, hidden_size].
With :num_classes, outputs [batch, num_classes].
@spec output_size(keyword()) :: pos_integer()
Get the output size of an MLP-Mixer model.
Returns :num_classes if set, otherwise :hidden_size.
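The documented rule for output_size/1 is simple enough to restate as a sketch (Python analogue for illustration; the dict keys mirror the Elixir options and the 512 fallback is the documented :hidden_size default):

```python
def output_size(opts):
    # :num_classes if set, otherwise :hidden_size (default 512).
    if opts.get("num_classes") is not None:
        return opts["num_classes"]
    return opts.get("hidden_size", 512)

print(output_size({"num_classes": 1000}))   # 1000
print(output_size({"hidden_size": 256}))    # 256
print(output_size({}))                      # 512
```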