Edifice.Vision.EfficientViT (Edifice v0.2.0)


EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction.

Implements EfficientViT from "EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction" (Liu et al., 2023). Achieves O(n) complexity instead of O(n²) via linear attention with cascaded group attention.

Key Innovations

  • Linear attention: Uses the kernel trick to avoid materializing the full attention matrix: queries and keys pass through a feature map φ, and φ(K)^T × V is computed first, giving O(n) complexity.
  • Cascaded group attention (CGA): Different heads see different channel splits of the input, enforcing head diversity and reducing redundancy.
  • Multi-scale: Progressive downsampling stages, each with its own dimension.
  • Depthwise conv in FFN: Adds local context between the linear layers (see the sketch after this list).
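
A minimal Axon sketch of that FFN idea, assuming a channels-first spatial layout ([batch, dim, h, w]), the default mlp_ratio, and illustrative names; this is not this module's exact implementation:

dim = 64
hidden = round(dim * 4.0)

ffn =
  Axon.input("tokens", shape: {nil, dim, 14, 14})
  # pointwise expansion (the first "linear" layer)
  |> Axon.conv(hidden, kernel_size: 1, channels: :first)
  |> Axon.activation(:gelu)
  # 3x3 depthwise conv adds local context per channel
  |> Axon.depthwise_conv(1, kernel_size: 3, padding: :same, channels: :first)
  # pointwise projection back to dim (the second "linear" layer)
  |> Axon.conv(dim, kernel_size: 1, channels: :first)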

Architecture

Image [batch, channels, height, width]
      |
      v
+---------------------------+
| Patch Embedding           |
+---------------------------+
      |
      v
+===========================+
| Stage 1 (depth[0] blocks) |
|  CGA Linear Attention     |
|  DW-Conv FFN              |
+===========================+
      | (downsample)
      v
+===========================+
| Stage 2 (depth[1] blocks) |
|  CGA Linear Attention     |
|  DW-Conv FFN              |
+===========================+
      | (downsample)
      v
+===========================+
| Stage 3 (depth[2] blocks) |
|  CGA Linear Attention     |
|  DW-Conv FFN              |
+===========================+
      |
      v
+---------------------------+
| LayerNorm + Global Pool   |
+---------------------------+
      |
      v
[batch, last_dim]

Cascaded Group Attention

Input: [batch, seq, dim]
       |
  Split into num_heads groups along dim
       |
  Head 0: [batch, seq, dim/heads] -> Q, K, V -> LinearAttn -> out_0
  Head 1: [batch, seq, dim/heads] -> Q, K, V -> LinearAttn -> out_1 + out_0
  Head 2: [batch, seq, dim/heads] -> Q, K, V -> LinearAttn -> out_2 + out_1
  ...
       |
  Concatenate all head outputs
       |
  Output projection

Each head sees its own slice of the feature map rather than a shared projection of the full input, and each head's output is folded into the next head's result (the cascade), which forces diverse attention patterns across heads.
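
A rough Nx sketch of the cascade; the dummy input, the identity placeholder for the per-head attention, and the exact way the previous head's output is folded in are assumptions for illustration:

x = Nx.iota({2, 196, 64}, type: :f32)    # stand-in for [batch, seq, dim]
linear_attn = &Function.identity/1       # placeholder for the per-head linear attention
num_heads = 4
head_dim = div(64, num_heads)

{outs, _last} =
  Enum.map_reduce(0..(num_heads - 1), nil, fn h, prev ->
    # Each head only ever sees its own dim/heads-wide slice of the input.
    slice = Nx.slice_along_axis(x, h * head_dim, head_dim, axis: 2)
    out = linear_attn.(slice)
    # Cascade: fold the previous head's output into this head's output.
    out = if prev, do: Nx.add(out, prev), else: out
    {out, out}
  end)

Nx.concatenate(outs, axis: 2)            # the real block applies an output projection afterwards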

Linear Attention

Standard attention: O(n²)

Attn = softmax(QK^T / √d) × V

Linear attention: O(n)

Attn = φ(Q) × (φ(K)^T × V)  where φ is ELU+1

By computing φ(K)^T × V first (d×d matrix), we avoid the n×n attention matrix entirely.
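
A minimal Nx sketch of the reordering on a single {n, d} = {196, 64} sequence; the normalization term of linear attention is omitted, and the dummy tensors exist only to make the shapes concrete:

# φ(x) = ELU(x) + 1 keeps features positive so the kernel trick applies.
phi = fn t -> Nx.add(Axon.Activations.elu(t), 1) end

q = Nx.iota({196, 64}, type: :f32) |> Nx.divide(1.0e4)
k = Nx.iota({196, 64}, type: :f32) |> Nx.divide(1.0e4)
v = Nx.iota({196, 64}, type: :f32) |> Nx.divide(1.0e4)

# Standard attention would build a {196, 196} matrix:
#   scores = Axon.Activations.softmax(Nx.divide(Nx.dot(q, Nx.transpose(k)), Nx.sqrt(64)))
#   out    = Nx.dot(scores, v)

# Linear attention computes the {64, 64} matrix φ(K)^T × V first:
kv  = Nx.dot(Nx.transpose(phi.(k)), phi.(v))   # {d, d}, cost O(n·d²)
out = Nx.dot(phi.(q), kv)                      # {n, d}; the {n, n} matrix never exists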

Usage

model = EfficientViT.build(
  image_size: 224,
  patch_size: 16,
  embed_dim: 64,
  depths: [1, 2, 3],
  num_heads: [4, 4, 4]
)
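
The result is a regular Axon graph, so it can be initialized and run as usual. The template shape below assumes the default 3-channel, 224 × 224 input; depending on your Axon version the second argument to init_fn is %{} or Axon.ModelState.empty():

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 3, 224, 224}, :f32), %{})
features = predict_fn.(params, Nx.broadcast(0.0, {1, 3, 224, 224}))
# features has shape {1, last_dim} since :num_classes was not given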

References

  • "EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction" (Liu et al., 2023)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build an EfficientViT model with linear attention.

output_size(opts \\ [])
Get the output size of an EfficientViT model.

Types

build_opt()

@type build_opt() ::
  {:depths, [pos_integer()]}
  | {:embed_dim, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, [pos_integer()]}
  | {:patch_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an EfficientViT model with linear attention.

Options

  • :image_size - Input image size, square (default: 224)
  • :patch_size - Patch size, square (default: 16)
  • :in_channels - Number of input channels (default: 3)
  • :embed_dim - Initial embedding dimension (default: 64)
  • :depths - Number of blocks per stage (default: [1, 2, 3])
  • :num_heads - Number of attention heads per stage (default: [4, 4, 4])
  • :mlp_ratio - MLP expansion ratio (default: 4.0)
  • :num_classes - Number of output classes (optional)

Returns

An Axon model. Without :num_classes, outputs [batch, last_dim]. With :num_classes, outputs [batch, num_classes].
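
For example (the class count is illustrative):

backbone   = EfficientViT.build()                    # outputs [batch, last_dim]
classifier = EfficientViT.build(num_classes: 1000)   # outputs [batch, 1000]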

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of an EfficientViT model.
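
A hedged usage sketch, assuming the default configuration and that the returned value is the feature width of the final stage (i.e. last_dim above), useful for sizing a downstream head:

feature_dim = EfficientViT.output_size()

head =
  Axon.input("features", shape: {nil, feature_dim})
  |> Axon.dense(10, activation: :softmax)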