EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction.
Implements EfficientViT from "EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction" (Cai et al., 2023). Achieves O(n) attention complexity instead of O(n²) by using linear attention, combined with cascaded group attention across heads.
Key Innovations
- Linear attention: Uses the kernel trick to avoid materializing the full attention matrix; Q and K are passed through a feature map φ so that φ(K)^T × V can be computed first, giving O(n) complexity.
- Cascaded group attention (CGA): Different heads see different channel splits of the input, enforcing head diversity and reducing redundancy.
- Multi-scale: Progressive downsampling stages, each with its own dimension.
- Depthwise conv in FFN: Adds local context between the FFN's linear layers (sketched right after this list).
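A minimal Axon sketch of that FFN pattern (the module and function names, the GELU activation, and the channels-last layout are illustrative assumptions, not this library's internals):

defmodule DWConvFFNSketch do
  # Hypothetical block, not the module's API. Two pointwise (1x1) convs play the
  # role of the FFN's linear layers; a depthwise 3x3 conv between them adds
  # local context. Operates on channels-last maps: {batch, height, width, channels}.
  def block(input, channels, mlp_ratio \\ 4.0) do
    hidden = trunc(channels * mlp_ratio)

    input
    |> Axon.conv(hidden, kernel_size: 1)
    |> Axon.conv(hidden,
      kernel_size: 3,
      padding: :same,
      # feature_group_size == channel count makes this a depthwise conv
      feature_group_size: hidden
    )
    |> Axon.activation(:gelu)
    |> Axon.conv(channels, kernel_size: 1)
  end
end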
Architecture
Image [batch, channels, height, width]
|
v
+--------------------------+
| Patch Embedding |
+--------------------------+
|
v
+==========================+
| Stage 1 (depth[0] blocks) |
| CGA Linear Attention |
| DW-Conv FFN |
+==========================+
| (downsample)
v
+==========================+
| Stage 2 (depth[1] blocks) |
| CGA Linear Attention |
| DW-Conv FFN |
+==========================+
| (downsample)
v
+==========================+
| Stage 3 (depth[2] blocks) |
| CGA Linear Attention |
| DW-Conv FFN |
+==========================+
|
v
+--------------------------+
| LayerNorm + Global Pool |
+--------------------------+
|
v
[batch, last_dim]

Cascaded Group Attention
Input: [batch, seq, dim]
|
Split into num_heads groups along dim
|
Head 0: [batch, seq, dim/heads]        → Q₀, K₀, V₀ → LinearAttn → out₀
Head 1: [batch, seq, dim/heads] + out₀ → Q₁, K₁, V₁ → LinearAttn → out₁
Head 2: [batch, seq, dim/heads] + out₁ → Q₂, K₂, V₂ → LinearAttn → out₂
...
|
Concatenate all head outputs
|
Output projection

Each head sees a different channel slice of the input, with the previous head's output folded into the next head's input. This forces diverse attention patterns across heads and reduces redundancy.
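Below is a rough Nx sketch of the cascade (the module name, the identity Q/K/V projections, and the omitted output projection and normalization are all simplifications for illustration, not this module's internals):

defmodule CGASketch do
  import Nx.Defn

  # ELU + 1 keeps features positive so the kernel trick stays valid
  defnp phi(x), do: Nx.select(x > 0, x + 1, Nx.exp(x))

  # One head of O(n) linear attention: build the d×d matrix φ(K)^T × V first,
  # then multiply by φ(Q); normalization is omitted for brevity.
  defn linear_attn(q, k, v) do
    kv = Nx.dot(phi(k), [1], [0], v, [1], [0])
    Nx.dot(phi(q), [2], [0], kv, [1], [0])
  end

  # Cascaded group attention over x: [batch, seq, dim] with `heads` channel groups.
  # Illustrative only: identity Q/K/V projections and no output projection.
  def cga(x, heads) do
    {_batch, _seq, dim} = Nx.shape(x)
    head_dim = div(dim, heads)

    {outs, _last} =
      Enum.map_reduce(0..(heads - 1), nil, fn h, prev ->
        slice = Nx.slice_along_axis(x, h * head_dim, head_dim, axis: 2)
        # cascade: add the previous head's output to this head's input slice
        input = if prev, do: Nx.add(slice, prev), else: slice
        out = linear_attn(input, input, input)
        {out, out}
      end)

    Nx.concatenate(outs, axis: 2)
  end
end

x = Nx.iota({2, 16, 32}, type: :f32)
CGASketch.cga(x, 4) |> Nx.shape()  # => {2, 16, 32}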
Linear Attention
Standard attention: O(n²)
Attn = softmax(QK^T / √d) × V

Linear attention: O(n)

Attn = φ(Q) × (φ(K)^T × V), where φ = ELU + 1

By computing φ(K)^T × V first (a d×d matrix), we avoid the n×n attention matrix entirely.
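Associativity is the whole trick. A small Nx sketch (the tensor shapes and the ELU + 1 feature map are assumptions for illustration) showing that both association orders agree, while only the quadratic one materializes the n×n matrix:

n = 4096
d = 64
key = Nx.Random.key(0)
{q, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {n, d})
{k, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {n, d})
{v, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {n, d})

# ELU + 1 feature map, written with explicit Nx calls since this runs outside defn
phi = fn x -> Nx.select(Nx.greater(x, 0), Nx.add(x, 1), Nx.exp(x)) end
q = phi.(q)
k = phi.(k)

# quadratic order: (φ(Q) × φ(K)^T) × V materializes an n×n (4096×4096) matrix
quadratic = q |> Nx.dot(Nx.transpose(k)) |> Nx.dot(v)

# linear order: φ(Q) × (φ(K)^T × V) never builds anything larger than d×d (64×64)
linear = Nx.dot(q, Nx.dot(Nx.transpose(k), v))

Nx.all_close(quadratic, linear) |> Nx.to_number()  # => 1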
Usage
model = EfficientViT.build(
image_size: 224,
patch_size: 16,
embed_dim: 64,
depths: [1, 2, 3],
num_heads: [4, 4, 4]
)

References
- Paper: "EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction"
- arXiv: https://arxiv.org/abs/2205.14756
Summary
Functions
build/1 - Build an EfficientViT model with linear attention.
output_size/1 - Get the output size of an EfficientViT model.
Types
@type build_opt() ::
        {:depths, [pos_integer()]}
        | {:embed_dim, pos_integer()}
        | {:image_size, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, [pos_integer()]}
        | {:patch_size, pos_integer()}
Options for build/1.
Functions
build/1

Build an EfficientViT model with linear attention.
Options
- :image_size - Input image size, square (default: 224)
- :patch_size - Patch size, square (default: 16)
- :in_channels - Number of input channels (default: 3)
- :embed_dim - Initial embedding dimension (default: 64)
- :depths - Number of blocks per stage (default: [1, 2, 3])
- :num_heads - Number of attention heads per stage (default: [4, 4, 4])
- :mlp_ratio - MLP expansion ratio (default: 4.0)
- :num_classes - Number of output classes (optional)
Returns
An Axon model. Without :num_classes, outputs [batch, last_dim].
With :num_classes, outputs [batch, num_classes].
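A hedged end-to-end example (assuming the model has a single image input in [batch, channels, height, width] layout, so a bare tensor can be passed to the compiled functions):

model =
  EfficientViT.build(
    image_size: 224,
    patch_size: 16,
    embed_dim: 64,
    depths: [1, 2, 3],
    num_heads: [4, 4, 4],
    num_classes: 1000
  )

{init_fn, predict_fn} = Axon.build(model)

image = Nx.broadcast(0.0, {1, 3, 224, 224})
params = init_fn.(image, %{})
predict_fn.(params, image) |> Nx.shape()  # => {1, 1000}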
output_size/1

@spec output_size(keyword()) :: pos_integer()

Get the output size of an EfficientViT model.
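Assuming output_size/1 takes the same options as build/1 (its spec above only says keyword()), it can be used to size a downstream head:

last_dim = EfficientViT.output_size(embed_dim: 64, depths: [1, 2, 3])

# e.g. attach a custom 10-class head on top of the pooled features
Axon.input("features", shape: {nil, last_dim}) |> Axon.dense(10)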