Data-efficient Image Transformer (DeiT) implementation.
Extends ViT with a distillation token that learns from a teacher model. During training, the CLS token produces the classification output and the distillation token produces the teacher-aligned output. At inference, both token outputs can be averaged for improved accuracy.
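The inference-time averaging mentioned above can be sketched in plain Elixir. This is an illustrative sketch, not part of this library: module and function names here are hypothetical. It softmaxes each head's logits, then averages the two probability distributions element-wise.

```elixir
defmodule DeiTInference do
  # Numerically plain softmax over a list of logits (illustrative only).
  def softmax(logits) do
    exps = Enum.map(logits, &:math.exp/1)
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # Fuse the CLS-head and distillation-head predictions by averaging
  # their softmax distributions, as described in the moduledoc.
  def fused_prediction(cls_logits, dist_logits) do
    Enum.zip_with(softmax(cls_logits), softmax(dist_logits), fn a, b ->
      (a + b) / 2
    end)
  end
end
```

Because each softmax sums to 1, the averaged distribution also sums to 1, so `argmax` over the fused list gives the combined prediction.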
Architecture
Image [batch, channels, height, width]
              |
              v
+----------------------------+
| Patch Embedding            |  Split into P x P patches, linear project
+----------------------------+
              |
              v
[batch, num_patches, embed_dim]
              |
              v
+----------------------------+
| Prepend CLS + Dist Tokens  |  Two learnable [1, 1, embed_dim] tokens
+----------------------------+
              |
              v
[batch, num_patches + 2, embed_dim]
              |
              v
+----------------------------+
| Add Position Embedding     |  Learnable [1, num_patches + 2, embed_dim]
+----------------------------+
              |
              v
+----------------------------+
| Transformer Block x N      |
|  LayerNorm -> Attention    |
|  -> Residual -> LayerNorm  |
|  -> MLP -> Residual        |
+----------------------------+
              |
              v
+----------------------------+
| Extract CLS (idx 0)        | -> Classification head
| Extract Dist (idx 1)       | -> Distillation head (teacher loss)
+----------------------------+

Usage
# DeiT-Base for ImageNet with distillation
model = DeiT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 768,
  depth: 12,
  num_heads: 12,
  num_classes: 1000,
  teacher_num_classes: 1000
)
# DeiT-Small without distillation head
model = DeiT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 384,
  depth: 12,
  num_heads: 6,
  num_classes: 1000
)

References
- "Training data-efficient image transformers & distillation through attention" (Touvron et al., ICML 2021)
Summary
Types
@type build_opt() ::
        {:depth, pos_integer()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:image_size, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, pos_integer()}
        | {:patch_size, pos_integer()}
        | {:teacher_num_classes, pos_integer()}
Options for build/1.
Functions
Build a DeiT model.
Options
  * :image_size - Input image size, square (default: 224)
  * :patch_size - Patch size, square (default: 16)
  * :in_channels - Number of input channels (default: 3)
  * :embed_dim - Embedding dimension (default: 768)
  * :depth - Number of transformer blocks (default: 12)
  * :num_heads - Number of attention heads (default: 12)
  * :mlp_ratio - MLP hidden dimension ratio relative to :embed_dim (default: 4.0)
  * :dropout - Dropout rate (default: 0.0)
  * :num_classes - Number of classes for the classification head (optional)
  * :teacher_num_classes - Number of classes for the distillation head (optional). If set, the model returns both the classification and distillation outputs via a container output (see Returns).
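As a concrete check of the token-count arithmetic implied by :image_size and :patch_size (the helper module and function here are hypothetical, for illustration only): a 224x224 image with 16x16 patches yields (224 / 16)^2 = 196 patch tokens, plus the CLS and distillation tokens.

```elixir
defmodule DeiTShapes do
  # Hypothetical helper: sequence length seen by the transformer blocks,
  # i.e. (image_size / patch_size)^2 patch tokens + CLS + Dist tokens.
  # Assumes image_size is evenly divisible by patch_size.
  def seq_len(image_size, patch_size) do
    side = div(image_size, patch_size)
    side * side + 2
  end
end

DeiTShapes.seq_len(224, 16)
# → 198, matching [batch, num_patches + 2, embed_dim] in the diagram above
```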
Returns
An Axon model. When both :num_classes and :teacher_num_classes are set,
outputs a container %{cls: [batch, num_classes], dist: [batch, teacher_num_classes]}.
When only :num_classes is set, outputs [batch, num_classes].
Otherwise, outputs [batch, embed_dim] from the CLS token.
@spec output_size(keyword()) :: pos_integer()
Get the output size of a DeiT model.
Returns :num_classes if set, otherwise :embed_dim.
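The fallback rule above can be sketched as follows. This is a minimal illustration of the documented behavior, not the library's actual implementation; the module name is hypothetical and the 768 fallback assumes the DeiT-Base default for :embed_dim.

```elixir
defmodule DeiTDoc do
  # Mirrors the documented rule for output_size/1: :num_classes if set,
  # otherwise :embed_dim (defaulting to 768, the DeiT-Base embedding dim).
  def output_size(opts) do
    Keyword.get(opts, :num_classes) || Keyword.get(opts, :embed_dim, 768)
  end
end
```

For example, `output_size(num_classes: 1000)` gives 1000, while `output_size(embed_dim: 384)` gives 384 for a headless DeiT-Small.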