Vision Transformer (ViT) implementation.
Treats an image as a sequence of fixed-size patches, linearly embeds each patch, prepends a learnable CLS token, adds position embeddings, and processes the resulting sequence through standard transformer encoder blocks.
Architecture
Image [batch, channels, height, width]
            |
            v
+------------------------+
|     Patch Embedding    |  Split into P x P patches, linear projection
+------------------------+
            |
            v
[batch, num_patches, embed_dim]
            |
            v
+------------------------+
|    Prepend CLS Token   |  Learnable [1, 1, embed_dim] token
+------------------------+
            |
            v
[batch, num_patches + 1, embed_dim]
            |
            v
+------------------------+
| Add Position Embedding |  Learnable [1, num_patches + 1, embed_dim]
+------------------------+
            |
            v
+------------------------+
|  Transformer Block x N |
|    LayerNorm           |
|    Self-Attention      |
|    Residual            |
|    LayerNorm           |
|    MLP (expand + GELU) |
|    Residual            |
+------------------------+
            |
            v
+------------------------+
|    Extract CLS Token   |  [batch, embed_dim]
+------------------------+
            |
            v
+------------------------+
|        LayerNorm       |
+------------------------+
            |
            v
+------------------------+
|   Optional Classifier  |  Dense -> num_classes
+------------------------+
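To make the diagram concrete, here is a condensed sketch of how such a pipeline can be assembled with Axon and Nx, using Axon's custom-layer API (Axon.layer/3 and Axon.param/3). It is an illustration, not this module's actual implementation: attention is single-head for brevity (a real ViT splits queries, keys, and values across num_heads heads), parameters are zero-initialized rather than properly initialized, dropout is omitted, and the ViTSketch module name and all layer names are hypothetical.

defmodule ViTSketch do
  def sketch(opts \\ []) do
    image_size = opts[:image_size] || 224
    patch_size = opts[:patch_size] || 16
    embed_dim = opts[:embed_dim] || 768
    depth = opts[:depth] || 12
    num_patches = div(image_size, patch_size) ** 2

    Axon.input("image", shape: {nil, 3, image_size, image_size})
    # A conv with kernel_size == strides == patch_size splits the image into
    # non-overlapping P x P patches and linearly projects each to embed_dim.
    |> Axon.conv(embed_dim,
      kernel_size: patch_size,
      strides: patch_size,
      channels: :first
    )
    # [batch, embed_dim, H/P, W/P] -> [batch, num_patches, embed_dim]
    |> Axon.nx(fn t ->
      {batch, _, _, _} = Nx.shape(t)
      t |> Nx.reshape({batch, embed_dim, num_patches}) |> Nx.transpose(axes: [0, 2, 1])
    end)
    |> prepend_cls_token(embed_dim)
    |> add_position_embedding(num_patches + 1, embed_dim)
    |> then(fn x ->
      Enum.reduce(1..depth, x, fn _i, acc -> block(acc, embed_dim) end)
    end)
    # Keep only the CLS token: [batch, seq, dim] -> [batch, dim]
    |> Axon.nx(fn t ->
      t |> Nx.slice_along_axis(0, 1, axis: 1) |> Nx.squeeze(axes: [1])
    end)
    |> Axon.layer_norm()
  end

  # Pre-norm block: LN -> attention -> residual, then LN -> MLP -> residual.
  defp block(x, embed_dim) do
    attn = x |> Axon.layer_norm() |> self_attention(embed_dim)
    x = Axon.add(x, attn)

    mlp =
      x
      |> Axon.layer_norm()
      |> Axon.dense(embed_dim * 4)
      |> Axon.activation(:gelu)
      |> Axon.dense(embed_dim)

    Axon.add(x, mlp)
  end

  # Scaled dot-product attention; single head for brevity.
  defp self_attention(x, embed_dim) do
    q = Axon.dense(x, embed_dim)
    k = Axon.dense(x, embed_dim)
    v = Axon.dense(x, embed_dim)

    Axon.layer(
      fn q, k, v, _opts ->
        scores = Nx.dot(q, [2], [0], k, [2], [0]) |> Nx.divide(Nx.sqrt(embed_dim))
        weights = Axon.Activations.softmax(scores)
        Nx.dot(weights, [2], [0], v, [1], [0])
      end,
      [q, k, v],
      op_name: :self_attention
    )
  end

  # Learnable CLS token, broadcast across the batch and prepended.
  defp prepend_cls_token(x, embed_dim) do
    cls = Axon.param("cls_token", {1, 1, embed_dim}, initializer: :zeros)

    Axon.layer(
      fn x, cls, _opts ->
        {batch, _, _} = Nx.shape(x)
        Nx.concatenate([Nx.broadcast(cls, {batch, 1, embed_dim}), x], axis: 1)
      end,
      [x, cls],
      op_name: :cls_token
    )
  end

  # Learnable position embedding, added to every sequence in the batch.
  defp add_position_embedding(x, seq_len, embed_dim) do
    pos = Axon.param("pos_embedding", {1, seq_len, embed_dim}, initializer: :zeros)
    Axon.layer(fn x, pos, _opts -> Nx.add(x, pos) end, [x, pos], op_name: :pos_embedding)
  end
end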
Usage

# ViT-Base for ImageNet
model = ViT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 768,
  depth: 12,
  num_heads: 12
)
# Small ViT for CIFAR-10
model = ViT.build(
  image_size: 32,
  patch_size: 4,
  embed_dim: 192,
  depth: 6,
  num_heads: 3,
  num_classes: 10
)
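Either model then goes through the standard Axon workflow. A quick sketch of a forward pass with the CIFAR-10 model above, on a dummy input:

{init_fn, predict_fn} = Axon.build(model)

template = Nx.template({1, 3, 32, 32}, :f32)
params = init_fn.(template, %{})

logits = predict_fn.(params, Nx.broadcast(0.0, {1, 3, 32, 32}))
# => tensor of shape {1, 10}, since num_classes: 10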
References
- "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., ICLR 2021)
Summary
Types
@type build_opt() ::
        {:depth, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:image_size, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, pos_integer()}
        | {:patch_size, pos_integer()}
Options for build/1.
Functions
build/1

Build a Vision Transformer model.
Options
- :image_size - Input image size, square (default: 224)
- :patch_size - Patch size, square (default: 16)
- :in_channels - Number of input channels (default: 3)
- :embed_dim - Embedding dimension (default: 768)
- :depth - Number of transformer blocks (default: 12)
- :num_heads - Number of attention heads (default: 12)
- :mlp_ratio - MLP hidden dim ratio relative to embed_dim (default: 4.0)
- :dropout - Dropout rate (default: 0.0)
- :num_classes - Number of classes for classification head (optional)
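For the defaults, the derived dimensions work out as follows (plain arithmetic, shown in Elixir):

num_patches = div(224, 16) ** 2   # => 196 patches per image
seq_len = num_patches + 1         # => 197 tokens, including CLS
mlp_hidden = trunc(768 * 4.0)     # => 3072, from mlp_ratio: 4.0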
Returns
An Axon model. Without :num_classes, outputs [batch, embed_dim].
With :num_classes, outputs [batch, num_classes].
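Because the result is an ordinary Axon graph, omitting :num_classes lets you attach your own head. A minimal sketch (the head size and layer name are illustrative):

# Bare backbone ending at the final LayerNorm: outputs [batch, 768].
backbone = ViT.build(image_size: 224, patch_size: 16, embed_dim: 768)

# Attach a hypothetical 100-class classifier head.
model = Axon.dense(backbone, 100, name: "custom_head")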
output_size/1

@spec output_size(keyword()) :: pos_integer()
Get the output size of a ViT model.
Returns :num_classes if set, otherwise :embed_dim.
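For example, mirroring the documented behavior (assuming it accepts the same keyword options as build/1):

ViT.output_size(embed_dim: 768)                   # => 768 (no classifier head)
ViT.output_size(embed_dim: 768, num_classes: 10)  # => 10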