Edifice.Vision.DeiT (Edifice v0.2.0)


Data-efficient Image Transformer (DeiT) implementation.

Extends ViT with a distillation token that learns from a teacher model. During training, the CLS token produces the classification output and the distillation token produces the teacher-aligned output. At inference, both token outputs can be averaged for improved accuracy.
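
A minimal sketch of that fused inference, assuming the model was built with both heads so that prediction returns a %{cls: ..., dist: ...} container (see build/1 below); cls_logits and dist_logits are hypothetical output tensors:

# Average the softmax of both heads, as in the DeiT paper's fused inference.
fused =
  Nx.add(
    Axon.Activations.softmax(cls_logits),
    Axon.Activations.softmax(dist_logits)
  )
  |> Nx.divide(2)

predicted = Nx.argmax(fused, axis: -1)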

Architecture

Image [batch, channels, height, width]
      |
+-----v---------------------+
| Patch Embedding           |  Split into P x P patches, linear project
+---------------------------+
      |
      v
[batch, num_patches, embed_dim]
      |
+-----v---------------------+
| Prepend CLS + Dist Tokens |  Two learnable [1, 1, embed_dim] tokens
+---------------------------+
      |
      v
[batch, num_patches + 2, embed_dim]
      |
+-----v---------------------+
| Add Position Embedding    |  Learnable [1, num_patches + 2, embed_dim]
+---------------------------+
      |
+-----v---------------------+
| Transformer Block x N     |
|   LayerNorm -> Attention  |
|   Residual -> LayerNorm   |
|   MLP -> Residual         |
+---------------------------+
      |
+-----v---------------------+
| Extract CLS (idx 0)       |  -> Classification head
| Extract Dist (idx 1)      |  -> Distillation head (teacher loss)
+---------------------------+
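
The teacher loss attached to the distillation head is, in the DeiT paper, a hard-distillation cross-entropy against the teacher's argmax predictions. A minimal sketch of that objective (the DistillLoss module and all names here are illustrative, not part of Edifice; assumes sparse integer labels and raw logits from both heads):

defmodule DistillLoss do
  import Nx.Defn

  # CLS head learns from ground-truth labels; distillation head learns from
  # the frozen teacher's hard predictions. Both terms are weighted equally.
  defn loss(labels, teacher_logits, cls_logits, dist_logits) do
    teacher_labels = Nx.argmax(teacher_logits, axis: -1)

    cls_loss =
      Axon.Losses.categorical_cross_entropy(labels, cls_logits,
        sparse: true,
        from_logits: true,
        reduction: :mean
      )

    dist_loss =
      Axon.Losses.categorical_cross_entropy(teacher_labels, dist_logits,
        sparse: true,
        from_logits: true,
        reduction: :mean
      )

    0.5 * cls_loss + 0.5 * dist_loss
  end
end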

Usage

# DeiT-Base for ImageNet with distillation
model = DeiT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 768,
  depth: 12,
  num_heads: 12,
  num_classes: 1000,
  teacher_num_classes: 1000
)

# DeiT-Small without distillation head
model = DeiT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 384,
  depth: 12,
  num_heads: 6,
  num_classes: 1000
)
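
A sketch of running the first (distilled) model; the channels-first input layout follows the Architecture diagram above:

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3, 224, 224}, :f32), %{})

images = Nx.broadcast(0.0, {1, 3, 224, 224})
%{cls: cls_logits, dist: dist_logits} = predict_fn.(params, images)
# cls_logits and dist_logits are each [1, 1000]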

References

  • "Training data-efficient image transformers & distillation through attention" (Touvron et al., ICML 2021)

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a DeiT model.

output_size(opts \\ [])

Get the output size of a DeiT model.

Types

build_opt()

@type build_opt() ::
  {:depth, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, pos_integer()}
  | {:patch_size, pos_integer()}
  | {:teacher_num_classes, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a DeiT model.

Options

  • :image_size - Input image size, square (default: 224)
  • :patch_size - Patch size, square (default: 16)
  • :in_channels - Number of input channels (default: 3)
  • :embed_dim - Embedding dimension (default: 768)
  • :depth - Number of transformer blocks (default: 12)
  • :num_heads - Number of attention heads (default: 12)
  • :mlp_ratio - MLP hidden dim ratio relative to embed_dim (default: 4.0)
  • :dropout - Dropout rate (default: 0.0)
  • :num_classes - Number of classes for classification head (optional)
  • :teacher_num_classes - Number of classes for the distillation head (optional). If set alongside :num_classes, the model returns the container %{cls: cls_output, dist: dist_output} described under Returns.

Returns

An Axon model. When both :num_classes and :teacher_num_classes are set, outputs a container %{cls: [batch, num_classes], dist: [batch, teacher_num_classes]}. When only :num_classes is set, outputs [batch, num_classes]. Otherwise, outputs [batch, embed_dim] from the CLS token.
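
For example, omitting both head options yields a backbone/feature extractor (a sketch based on the defaults listed above):

# No :num_classes and no :teacher_num_classes: the model outputs the final
# CLS token representation, [batch, 768] with the default embed_dim.
backbone =
  DeiT.build(
    image_size: 224,
    patch_size: 16,
    embed_dim: 768,
    depth: 12,
    num_heads: 12
  )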

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a DeiT model.

Returns the value of :num_classes if set, otherwise the value of :embed_dim.
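
Following that rule:

DeiT.output_size(num_classes: 1000)
#=> 1000

DeiT.output_size(embed_dim: 384)
#=> 384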