Edifice.Vision.ViT (Edifice v0.2.0)


Vision Transformer (ViT) implementation.

Treats an image as a sequence of fixed-size patches, linearly embeds each patch, prepends a learnable CLS token, adds position embeddings, and processes the resulting sequence through standard transformer encoder blocks.
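The patch-splitting step can be shown without Nx or Axon. The sketch below uses a hypothetical helper (`PatchSketch.patchify/2`, not part of Edifice) that works on a single-channel image stored as a list of rows and assumes row-major patch order. Edifice itself operates on Nx tensors and may order or project patches differently.

```elixir
defmodule PatchSketch do
  # Hypothetical helper, not an Edifice API. Splits a square, single-channel
  # image (a list of rows) into non-overlapping p x p patches, in row-major
  # patch order, and flattens each patch into a list of p * p values.
  def patchify(image, p) do
    image
    # group the image into horizontal bands of p rows each
    |> Enum.chunk_every(p)
    |> Enum.flat_map(fn band ->
      band
      # cut every row in the band into p-wide pieces
      |> Enum.map(&Enum.chunk_every(&1, p))
      # stitch the pieces at the same column position into one flat patch
      |> Enum.zip_with(&List.flatten/1)
    end)
  end
end

image = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
PatchSketch.patchify(image, 2)
# => [[1, 2, 5, 6], [3, 4, 7, 8], [9, 10, 13, 14], [11, 12, 15, 16]]
```

Each flattened patch holds P * P * channels values (here 2 * 2 * 1 = 4); the patch embedding then maps that vector to `:embed_dim` with a learned linear projection.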

Architecture

Image [batch, channels, height, width]
      |
+-----v------------------+
| Patch Embedding        |  Split into P x P patches, linearly project
+------------------------+
      |
      v
[batch, num_patches, embed_dim]
      |
+-----v------------------+
| Prepend CLS Token      |  Learnable [1, 1, embed_dim] token
+------------------------+
      |
      v
[batch, num_patches + 1, embed_dim]
      |
+-----v------------------+
| Add Position Embedding |  Learnable [1, num_patches + 1, embed_dim]
+------------------------+
      |
+-----v------------------+
| Transformer Block x N  |
|   LayerNorm            |
|   Self-Attention       |
|   Residual             |
|   LayerNorm            |
|   MLP (expand + GELU)  |
|   Residual             |
+------------------------+
      |
+-----v------------------+
| Extract CLS Token      |  [batch, embed_dim]
+------------------------+
      |
+-----v------------------+
| LayerNorm              |
+------------------------+
      |
+-----v------------------+
| Optional Classifier    |  Dense -> num_classes
+------------------------+
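The shapes in the diagram follow directly from the options. As a rough sketch, the hypothetical `ViTShapeTrace.trace/1` below (not part of Edifice) recomputes them using the defaults documented for `build/1` and assumes `:image_size` is divisible by `:patch_size`; `:batch` stands in for the batch dimension.

```elixir
defmodule ViTShapeTrace do
  # Hypothetical helper, not an Edifice API. Returns the tensor shape after
  # each stage of the diagram, using the defaults documented for build/1.
  def trace(opts \\ []) do
    image_size = Keyword.get(opts, :image_size, 224)
    patch_size = Keyword.get(opts, :patch_size, 16)
    in_channels = Keyword.get(opts, :in_channels, 3)
    embed_dim = Keyword.get(opts, :embed_dim, 768)
    num_classes = Keyword.get(opts, :num_classes)

    # assumes image_size is a multiple of patch_size
    side = div(image_size, patch_size)
    num_patches = side * side
    b = :batch

    stages = [
      input: {b, in_channels, image_size, image_size},
      patch_embedding: {b, num_patches, embed_dim},
      # the CLS token adds one position to the sequence
      with_cls_token: {b, num_patches + 1, embed_dim},
      # transformer blocks preserve the sequence shape
      after_blocks: {b, num_patches + 1, embed_dim},
      cls_token: {b, embed_dim}
    ]

    if num_classes, do: stages ++ [classifier: {b, num_classes}], else: stages
  end
end
```

For the ViT-Base defaults this gives 196 patches and a sequence length of 197; for 32 x 32 inputs with 4 x 4 patches it gives 64 patches.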

Usage

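alias Edifice.Vision.ViT

# ViT-Base/16 backbone for 224x224 ImageNet-sized inputs
# (no :num_classes, so it outputs [batch, 768] features)
model = ViT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 768,
  depth: 12,
  num_heads: 12
)

# Small ViT for CIFAR-10 (32x32 images, 4x4 patches -> 64 patches, 10 classes)
model = ViT.build(
  image_size: 32,
  patch_size: 4,
  embed_dim: 192,
  depth: 6,
  num_heads: 3,
  num_classes: 10
)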

References

  • "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., ICLR 2021)

Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a Vision Transformer model.

output_size(opts \\ []) - Get the output size of a ViT model.

Types

build_opt()

@type build_opt() ::
  {:depth, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, pos_integer()}
  | {:patch_size, pos_integer()}

Options for build/1.
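A standard ViT relies on two divisibility constraints between these options: the image must tile evenly into patches, and `:embed_dim` must split evenly across `:num_heads`. The hypothetical `ViTOptsCheck.check/1` below (not part of Edifice) sketches a pre-flight check for both; whether `build/1` itself validates them is not stated here, so treat this as an assumption.

```elixir
defmodule ViTOptsCheck do
  # Hypothetical helper, not an Edifice API. Checks the divisibility
  # constraints a standard ViT relies on, using build/1's documented defaults.
  @defaults [image_size: 224, patch_size: 16, embed_dim: 768, num_heads: 12]

  def check(opts \\ []) do
    opts = Keyword.merge(@defaults, opts)

    cond do
      rem(opts[:image_size], opts[:patch_size]) != 0 ->
        {:error, :image_size_not_divisible_by_patch_size}

      rem(opts[:embed_dim], opts[:num_heads]) != 0 ->
        {:error, :embed_dim_not_divisible_by_num_heads}

      true ->
        side = div(opts[:image_size], opts[:patch_size])
        # per-head width of the attention projections, and patch count
        {:ok, %{head_dim: div(opts[:embed_dim], opts[:num_heads]), num_patches: side * side}}
    end
  end
end
```

Both configurations in the Usage section pass and use a head dimension of 64.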

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Vision Transformer model.

Options

  • :image_size - Input image height and width in pixels; images are square (default: 224)
  • :patch_size - Height and width of each square patch in pixels (default: 16)
  • :in_channels - Number of input channels (default: 3)
  • :embed_dim - Embedding dimension (default: 768)
  • :depth - Number of transformer blocks (default: 12)
  • :num_heads - Number of attention heads (default: 12)
  • :mlp_ratio - Ratio of the MLP hidden dimension to :embed_dim (default: 4.0)
  • :dropout - Dropout rate (default: 0.0)
  • :num_classes - Number of classes for the classification head; when omitted (nil), no head is added
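To see how these options drive model size, the hypothetical `ViTParamCount.estimate/1` below (not an Edifice API) counts parameters for the layout in the Architecture diagram. It assumes every Dense layer has a bias, each LayerNorm has a scale and a bias, and attention uses separate Q, K, V and output projections of size embed_dim x embed_dim; Edifice's actual count may differ if its layers are configured otherwise.

```elixir
defmodule ViTParamCount do
  # Hypothetical estimate, not an Edifice API. Assumes biases on all Dense
  # layers, scale + bias on each LayerNorm, and four embed_dim x embed_dim
  # attention projections (Q, K, V, output) per block.
  def estimate(opts \\ []) do
    image_size = Keyword.get(opts, :image_size, 224)
    patch_size = Keyword.get(opts, :patch_size, 16)
    in_channels = Keyword.get(opts, :in_channels, 3)
    d = Keyword.get(opts, :embed_dim, 768)
    depth = Keyword.get(opts, :depth, 12)
    mlp_ratio = Keyword.get(opts, :mlp_ratio, 4.0)
    num_classes = Keyword.get(opts, :num_classes)

    side = div(image_size, patch_size)
    num_patches = side * side
    patch_dim = patch_size * patch_size * in_channels
    hidden = round(d * mlp_ratio)

    # weights + bias of a Dense layer mapping inp -> out
    dense = fn inp, out -> inp * out + out end
    layer_norm = 2 * d

    # patch projection + CLS token + position embeddings
    embedding = dense.(patch_dim, d) + d + (num_patches + 1) * d

    # LN + attention projections, then LN + two-layer MLP
    block =
      layer_norm + 4 * dense.(d, d) +
        layer_norm + dense.(d, hidden) + dense.(hidden, d)

    head = if num_classes, do: dense.(d, num_classes), else: 0

    # final LayerNorm on the CLS token sits between the blocks and the head
    embedding + depth * block + layer_norm + head
  end
end

ViTParamCount.estimate()
# => 85798656
ViTParamCount.estimate(num_classes: 1000)
# => 86567656
```

Under these assumptions the defaults land near the 86M parameters reported for ViT-Base in the original paper; nearly all of them sit in the `:depth` transformer blocks, whose size grows with `:embed_dim` squared and with `:mlp_ratio`.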

Returns

An Axon model. Without :num_classes, it outputs the normalized CLS token with shape [batch, embed_dim]. With :num_classes, it outputs [batch, num_classes].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a ViT model.

Returns :num_classes if set, otherwise :embed_dim.
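The rule above is simple enough to restate. The sketch below re-implements it in a hypothetical `OutputSizeSketch` module for illustration only (in a real project you would call `ViT.output_size/1` directly); it assumes the `:embed_dim` default of 768 documented for `build/1` also applies here.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical restatement of the documented rule, not Edifice's source:
  # :num_classes when set, otherwise :embed_dim (assumed default 768).
  def output_size(opts \\ []) do
    Keyword.get(opts, :num_classes) || Keyword.get(opts, :embed_dim, 768)
  end
end
```

A typical use is sizing a downstream layer without building the model first: a backbone built with `embed_dim: 192` and no `:num_classes` yields 192 features per image.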