Edifice.Vision.MLPMixer (Edifice v0.2.0)


MLP-Mixer - All-MLP architecture for vision.

Replaces attention and convolutions entirely with MLPs. Uses two types of MLP layers applied alternately: token-mixing MLPs that operate across spatial locations (patches), and channel-mixing MLPs that operate within each location independently.

Architecture

Image [batch, channels, height, width]
      |
+-----v---------------------+
| Patch Embedding           |  Split into P x P patches, linear projection
+---------------------------+
      |
      v
[batch, num_patches, hidden_size]
      |
+-----v---------------------+
| Mixer Layer x N           |
|                           |
| Token Mixing:             |
|   LN -> Transpose         |
|   -> Dense(token_mlp_dim) |
|   -> GELU                 |
|   -> Dense(num_patches)   |
|   -> Transpose            |
|   + Residual              |
|                           |
| Channel Mixing:           |
|   LN -> Dense(ch_mlp_dim) |
|   -> GELU                 |
|   -> Dense(hidden_size)   |
|   + Residual              |
+---------------------------+
      |
+-----v---------------------+
| LayerNorm                 |
+---------------------------+
      |
+-----v---------------------+
| Global Average Pool       |  Mean over patches
+---------------------------+
      |
      v
[batch, hidden_size]
      |
+-----v---------------------+
| Optional Classifier       |
+---------------------------+
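
To make the data flow in the diagram concrete, here is a dependency-free sketch of a single mixer layer for one image, written with plain Elixir lists rather than Nx/Axon. It is not Edifice's implementation: the module `MixerSketch` is hypothetical, LayerNorm omits its learnable scale and shift, GELU uses the tanh approximation, and dropout is left out. The two `transpose/1` calls are what turn an ordinary per-row MLP into one that mixes across patches.

```elixir
defmodule MixerSketch do
  # Hypothetical sketch: a "matrix" is a list of rows, so x is
  # [num_patches][hidden_size] for a single image (no batch axis).

  # GELU, tanh approximation.
  def gelu(v) do
    0.5 * v * (1 + :math.tanh(:math.sqrt(2 / :math.pi()) * (v + 0.044715 * v * v * v)))
  end

  # LayerNorm over one vector, without learnable scale and shift.
  def layer_norm(row, eps \\ 1.0e-6) do
    n = length(row)
    mean = Enum.sum(row) / n
    var = Enum.sum(Enum.map(row, fn v -> (v - mean) * (v - mean) end)) / n
    Enum.map(row, fn v -> (v - mean) / :math.sqrt(var + eps) end)
  end

  # Swap rows and columns: [a][b] -> [b][a].
  def transpose(m), do: m |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  # Dense layer on one vector: w is [in][out], b is [out].
  def dense(row, w, b) do
    w
    |> transpose()
    |> Enum.zip_with(b, fn col, bias -> dot(row, col) + bias end)
  end

  # Two-layer MLP (Dense -> GELU -> Dense) applied to every row independently.
  def mlp(rows, {w1, b1, w2, b2}) do
    Enum.map(rows, fn row ->
      row |> dense(w1, b1) |> Enum.map(&gelu/1) |> dense(w2, b2)
    end)
  end

  # One mixer layer, following the diagram.
  def mixer_layer(x, token_params, channel_params) do
    # Token mixing: LN per patch, transpose to [hidden][patches],
    # mix across patches, transpose back, add the residual.
    token_out =
      x |> Enum.map(&layer_norm/1) |> transpose() |> mlp(token_params) |> transpose()

    x = add(x, token_out)

    # Channel mixing: LN and MLP within each patch, add the residual.
    channel_out = x |> Enum.map(&layer_norm/1) |> mlp(channel_params)
    add(x, channel_out)
  end

  defp dot(a, b), do: a |> Enum.zip_with(b, &(&1 * &2)) |> Enum.sum()

  defp add(a, b), do: Enum.zip_with(a, b, fn ra, rb -> Enum.zip_with(ra, rb, &(&1 + &2)) end)
end
```

Token-mixing parameters are `{w1, b1, w2, b2}` with `w1` shaped [num_patches][token_mlp_dim] and `w2` shaped [token_mlp_dim][num_patches]; channel-mixing parameters use hidden_size and channel_mlp_dim in the same positions. Changing one patch can only affect other patches through the token-mixing step.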

Key Insight

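Token-mixing MLPs allow communication between different spatial locations, while channel-mixing MLPs process features within each location. A single token-mixing MLP is shared across all channels, and a single channel-mixing MLP is shared across all patches. This separation is analogous to a depthwise separable convolution: channel mixing acts like a 1x1 convolution, and token mixing acts like a depthwise convolution whose kernel covers the whole image and whose weights are tied across channels. Using only fully-connected layers, the paper reports competitive scores on image classification benchmarks without attention, when trained on large datasets or with modern regularization.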
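
The weight sharing shows up clearly when counting parameters in one mixer layer. The sketch below is a hypothetical helper, assuming each Dense layer has a bias and each LayerNorm has a scale and a shift; Edifice's exact parameter layout is not documented here.

```elixir
defmodule MixerParams do
  # Parameters of one Dense layer with bias (assumed).
  defp dense(fan_in, fan_out), do: fan_in * fan_out + fan_out

  # Two-layer MLP: dim -> mlp_dim -> dim.
  defp mlp(dim, mlp_dim), do: dense(dim, mlp_dim) + dense(mlp_dim, dim)

  # Parameter counts for a single mixer layer.
  def per_layer(opts) do
    side = div(opts[:image_size], opts[:patch_size])
    num_patches = side * side
    hidden = opts[:hidden_size]

    %{
      # One token MLP over the patch axis, reused for every channel.
      token_mlp: mlp(num_patches, opts[:token_mlp_dim]),
      # One channel MLP over the feature axis, reused for every patch.
      channel_mlp: mlp(hidden, opts[:channel_mlp_dim]),
      # Two LayerNorms, each with scale and shift vectors of size hidden.
      layer_norms: 2 * 2 * hidden
    }
  end
end
```

For the Mixer-B/16 settings in the Usage section, this gives 151,108 token-mixing parameters versus 4,722,432 channel-mixing parameters per layer: the token MLP scales with the number of patches, not with hidden_size.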

Usage

alias Edifice.Vision.MLPMixer

# MLP-Mixer-B/16
model = MLPMixer.build(
  image_size: 224,
  patch_size: 16,
  hidden_size: 768,
  num_layers: 12,
  token_mlp_dim: 384,
  channel_mlp_dim: 3072,
  num_classes: 1000
)

# Small Mixer for CIFAR-10
model = MLPMixer.build(
  image_size: 32,
  patch_size: 4,
  hidden_size: 256,
  num_layers: 8,
  token_mlp_dim: 128,
  channel_mlp_dim: 1024,
  num_classes: 10
)
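
The sequence length seen by the mixer layers follows from the image and patch sizes: (image_size / patch_size)^2 patches, so 196 for the B/16 configuration and 64 for the CIFAR-10 configuration. The hypothetical `MixerShapes.trace/1` below traces per-example shapes (batch axis omitted) through the diagram. It raises when the image size is not a multiple of the patch size; that check is this sketch's assumption, since these docs do not say how build/1 handles it.

```elixir
defmodule MixerShapes do
  # Hypothetical helper: per-example shapes through the MLP-Mixer diagram.
  def trace(opts) do
    image = Keyword.fetch!(opts, :image_size)
    patch = Keyword.fetch!(opts, :patch_size)
    hidden = Keyword.fetch!(opts, :hidden_size)

    if rem(image, patch) != 0 do
      raise ArgumentError, "image_size #{image} is not divisible by patch_size #{patch}"
    end

    side = div(image, patch)
    num_patches = side * side

    %{
      num_patches: num_patches,
      # Values in one flattened P x P x C patch, the input to the linear projection.
      patch_values: patch * patch * Keyword.get(opts, :in_channels, 3),
      # Sequence processed by the mixer layers.
      tokens: {num_patches, hidden},
      # After LayerNorm and mean over patches.
      pooled: {hidden},
      # Classifier output if :num_classes is given, else the pooled features.
      output: {Keyword.get(opts, :num_classes) || hidden}
    }
  end
end
```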

References

  • "MLP-Mixer: An all-MLP Architecture for Vision" (Tolstikhin et al., NeurIPS 2021)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build an MLP-Mixer model.

output_size(opts \\ [])
Get the output size of an MLP-Mixer model.

Types

build_opt()

@type build_opt() ::
  {:channel_mlp_dim, pos_integer()}
  | {:dropout, float()}
  | {:hidden_size, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:num_classes, pos_integer() | nil}
  | {:num_layers, pos_integer()}
  | {:patch_size, pos_integer()}
  | {:token_mlp_dim, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an MLP-Mixer model.

Options

  • :image_size - Input image height and width in pixels; images are square (default: 224)
  • :patch_size - Patch height and width in pixels (default: 16)
  • :in_channels - Number of input channels (default: 3)
  • :hidden_size - Embedding dimension of each patch (default: 512)
  • :num_layers - Number of mixer layers (default: 8)
  • :token_mlp_dim - Hidden dimension of the token-mixing MLP (default: 256)
  • :channel_mlp_dim - Hidden dimension of the channel-mixing MLP (default: 2048)
  • :dropout - Dropout rate (default: 0.0)
  • :num_classes - Number of classes for the classification head (optional; if omitted, no head is added)

The defaults match the Mixer-S/16 configuration from the paper.
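
A minimal sketch of how these defaults combine with caller-supplied options, assuming validation in the style of the standard library's `Keyword.validate!/2` (Elixir 1.13+). `MixerOpts` is a hypothetical module; whether build/1 itself rejects unknown keys is not stated in these docs.

```elixir
defmodule MixerOpts do
  # Hypothetical helper; defaults copied from the option list above.
  @defaults [
    image_size: 224,
    patch_size: 16,
    in_channels: 3,
    hidden_size: 512,
    num_layers: 8,
    token_mlp_dim: 256,
    channel_mlp_dim: 2048,
    dropout: 0.0,
    num_classes: nil
  ]

  # Fills in missing keys with defaults and raises ArgumentError on unknown keys.
  def resolve(opts \\ []), do: Keyword.validate!(opts, @defaults)
end
```

A typo such as `hidden: 256` then fails loudly instead of silently falling back to the default hidden size.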

Returns

An Axon model. Without :num_classes, outputs [batch, hidden_size]. With :num_classes, outputs [batch, num_classes].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of an MLP-Mixer model.

Returns :num_classes if set, otherwise :hidden_size.
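
The documented rule fits in one expression. The sketch below is hypothetical (`MixerOutput` is not part of Edifice), and its fallback to 512 when neither key is given assumes the function applies the :hidden_size default from build/1.

```elixir
defmodule MixerOutput do
  # :num_classes if set (non-nil), otherwise :hidden_size.
  # The 512 fallback mirrors the documented :hidden_size default (assumption).
  def output_size(opts \\ []) do
    Keyword.get(opts, :num_classes) || Keyword.get(opts, :hidden_size, 512)
  end
end
```

This is handy for sizing a layer stacked on top of the model without building the model first.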