Swin Transformer (Shifted Window Transformer) implementation.
A hierarchical vision transformer that computes attention within local windows and shifts windows between layers for cross-window connections. Produces multi-scale feature maps like a CNN, making it suitable for dense prediction tasks.
Architecture
Image [batch, channels, height, width]
              |
              v
+---------------------------+
| Patch Embedding           |  patch_size x patch_size, linear projection
+---------------------------+
              |
              v
[batch, H/4 * W/4, embed_dim]
              |
              v
+---------------------------+
| Stage 1                   |  depths[0] Swin blocks at embed_dim
|   Window Attention        |  alternating regular/shifted windows
+---------------------------+
              |
              v
+---------------------------+
| Patch Merging             |  2x2 spatial merge, 2x channel expand
+---------------------------+
              |
              v
+---------------------------+
| Stage 2                   |  depths[1] blocks at embed_dim * 2
+---------------------------+
              |
              v
+---------------------------+
| Patch Merging             |
+---------------------------+
              |
              v
+---------------------------+
| Stage 3                   |  depths[2] blocks at embed_dim * 4
+---------------------------+
              |
              v
+---------------------------+
| Patch Merging             |
+---------------------------+
              |
              v
+---------------------------+
| Stage 4                   |  depths[3] blocks at embed_dim * 8
+---------------------------+
              |
              v
+---------------------------+
| Global Average Pooling    |
+---------------------------+
              |
              v
+---------------------------+
| LayerNorm                 |
+---------------------------+
              |
              v
+---------------------------+
| Optional Classifier       |
+---------------------------+

Window Attention
Attention is computed within non-overlapping local windows of M x M tokens, reducing complexity from O(N^2) to O(N * M^2). Shifted windows in alternating layers enable cross-window information flow via cyclic shift and masked attention.
Features:
- Real window partitioning with M x M local attention
- Multi-head scaled dot-product attention within each window
- Cyclic shift for shifted window attention with boundary masking
- Learnable relative position bias per attention head
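To make the savings concrete, a back-of-the-envelope count for stage 1 of Swin-Tiny (56 x 56 tokens at 224 resolution with patch size 4, window size M = 7):

n = 56 * 56           # 3136 tokens at stage 1
m = 7                 # window size
global = n * n        # 9_834_496 attention pairs per head
windowed = n * m * m  # 153_664 pairs: a 64x reduction (n / m^2)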
Usage
# Swin-Tiny
model = SwinTransformer.build(
  image_size: 224,
  patch_size: 4,
  embed_dim: 96,
  depths: [2, 2, 6, 2],
  num_heads: [3, 6, 12, 24],
  num_classes: 1000
)
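Once built, the model runs through the standard Axon flow. A minimal sketch, assuming the Swin-Tiny model above and the channels-first [batch, channels, height, width] input from the architecture diagram:

{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template, then run inference
params = init_fn.(Nx.template({1, 3, 224, 224}, :f32), %{})
logits = predict_fn.(params, Nx.broadcast(0.0, {1, 3, 224, 224}))
# logits has shape {1, 1000}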
- "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (Liu et al., ICCV 2021)
Summary
Functions
Build a Swin Transformer model.
Compute relative position bias for window attention.
Compute attention mask for shifted windows.
Cyclic shift: roll tensor by -shift_size along both H and W axes.
Get the output size of a Swin Transformer model.
Reverse cyclic shift: roll tensor by +shift_size along both H and W axes.
Partition a spatial tensor into non-overlapping windows.
Reverse window partition back to spatial layout.
Types
@type build_opt() ::
        {:depths, [pos_integer()]}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:image_size, pos_integer()}
        | {:in_channels, pos_integer()}
        | {:mlp_ratio, float()}
        | {:num_classes, pos_integer() | nil}
        | {:num_heads, [pos_integer()]}
        | {:patch_size, pos_integer()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Swin Transformer model.
Options
- :image_size - Input image size, square (default: 224)
- :patch_size - Initial patch embedding size (default: 4)
- :in_channels - Number of input channels (default: 3)
- :embed_dim - Base embedding dimension (default: 96)
- :depths - Number of blocks per stage (default: [2, 2, 6, 2])
- :num_heads - Number of attention heads per stage (default: [3, 6, 12, 24])
- :window_size - Window size for local attention (default: 7)
- :mlp_ratio - MLP hidden dimension ratio (default: 4.0)
- :dropout - Dropout rate (default: 0.0)
- :num_classes - Number of classes for the classification head (optional)
Spatial dimensions at each stage must be divisible by the effective window size.
Returns
An Axon model. Without :num_classes, outputs [batch, embed_dim * 2^(num_stages-1)].
With :num_classes, outputs [batch, num_classes].
@spec compute_relative_position_bias(pos_integer(), pos_integer()) :: Nx.Tensor.t()
Compute relative position bias for window attention.
Uses distance-based decay with per-head geometric slopes, similar to ALiBi but for 2D windows. Each head gets a different slope, providing diverse position sensitivity across heads.
Returns a [1, num_heads, ws*ws, ws*ws] bias tensor.
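A minimal sketch of one plausible bias of this shape, assuming Euclidean distance and ALiBi-style slopes 2^(-8h / num_heads); the exact decay used by compute_relative_position_bias/2 may differ in detail:

def position_bias_sketch(ws, num_heads) do
  # (row, col) coordinates of every token in the window, shape {ws*ws, 2}
  coords = Nx.tensor(for i <- 0..(ws - 1), j <- 0..(ws - 1), do: [i, j])

  # Pairwise Euclidean distances between tokens, shape {ws*ws, ws*ws}
  diff = Nx.subtract(Nx.new_axis(coords, 1), Nx.new_axis(coords, 0))
  dist = diff |> Nx.multiply(diff) |> Nx.sum(axes: [-1]) |> Nx.sqrt()

  # Geometric per-head slopes, as in ALiBi: head h gets 2^(-8h / num_heads)
  slopes = Nx.tensor(for h <- 1..num_heads, do: :math.pow(2, -8 * h / num_heads))

  # bias[0, h, i, j] = -slope_h * dist[i, j] -> {1, num_heads, ws*ws, ws*ws}
  dist
  |> Nx.new_axis(0)
  |> Nx.multiply(Nx.reshape(slopes, {num_heads, 1, 1}))
  |> Nx.negate()
  |> Nx.new_axis(0)
end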
@spec compute_shift_mask(pos_integer(), pos_integer(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Compute attention mask for shifted windows.
Assigns region IDs based on position relative to shift boundaries, then creates a pairwise mask that blocks attention between tokens from different regions within each window.
Returns a [num_windows, ws*ws, ws*ws] mask tensor with 0.0 for allowed attention and -100.0 for blocked attention.
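This follows the reference Swin recipe. A self-contained sketch that inlines its own window partitioning rather than calling this module's helpers (whose argument order is not shown here):

def shift_mask_sketch(h, w, ws, shift) do
  # Three regions per axis: [0, dim-ws), [dim-ws, dim-shift), [dim-shift, dim)
  region = fn p, dim ->
    cond do
      p < dim - ws -> 0
      p < dim - shift -> 1
      true -> 2
    end
  end

  # Unique region ID per (row region, col region) pair, shape {h, w}
  ids =
    Nx.tensor(
      for i <- 0..(h - 1) do
        for j <- 0..(w - 1), do: 3 * region.(i, h) + region.(j, w)
      end
    )

  # Partition the ID map into windows, shape {num_windows, ws*ws}
  ids =
    ids
    |> Nx.reshape({div(h, ws), ws, div(w, ws), ws})
    |> Nx.transpose(axes: [0, 2, 1, 3])
    |> Nx.reshape({div(h, ws) * div(w, ws), ws * ws})

  # 0.0 where both tokens share a region, -100.0 otherwise
  Nx.not_equal(Nx.new_axis(ids, 2), Nx.new_axis(ids, 1))
  |> Nx.select(-100.0, 0.0)
end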
@spec cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Cyclic shift: roll tensor by -shift_size along both H and W axes.
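One way to roll a tensor in Nx is slice-and-concatenate. A sketch of both directions with NumPy-style semantics (output[i] = input[(i - k) mod len] along the given axis); the internals of cyclic_shift/4 and reverse_cyclic_shift/4 are assumptions:

def roll_sketch(x, k, axis) do
  len = Nx.axis_size(x, axis)
  s = Integer.mod(-k, len)

  if s == 0 do
    x
  else
    Nx.concatenate(
      [
        Nx.slice_along_axis(x, s, len - s, axis: axis),
        Nx.slice_along_axis(x, 0, s, axis: axis)
      ],
      axis: axis
    )
  end
end

# Cyclic shift: roll by -shift along H (axis 1) and W (axis 2) of [B, H, W, C]
shifted = x |> roll_sketch(-shift, 1) |> roll_sketch(-shift, 2)
# Reverse cyclic shift: roll back by +shift
restored = shifted |> roll_sketch(shift, 1) |> roll_sketch(shift, 2)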
@spec output_size(keyword()) :: pos_integer()
Get the output size of a Swin Transformer model.
Returns :num_classes if set, otherwise embed_dim * 2^(num_stages - 1).
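For example, with the Swin-Tiny defaults (embed_dim 96, four stages) and no :num_classes:

SwinTransformer.output_size(embed_dim: 96, depths: [2, 2, 6, 2])
#=> 768    # 96 * 2^3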
@spec reverse_cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Reverse cyclic shift: roll tensor by +shift_size along both H and W axes.
@spec window_partition(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) :: Nx.Tensor.t()
Partition a spatial tensor into non-overlapping windows.
Input: [B, H, W, C] -> Output: [B*nW, ws*ws, C], where nW is the number of windows per image.
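A minimal reshape/transpose sketch of the partition; the argument order here is an assumption, not necessarily this module's:

def window_partition_sketch(x, h, w, ws) do
  {b, ^h, ^w, c} = Nx.shape(x)

  x
  # Split H and W into (num_windows, ws) pairs
  |> Nx.reshape({b, div(h, ws), ws, div(w, ws), ws, c})
  # Bring the two window-grid axes together, then the two in-window axes
  |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5])
  |> Nx.reshape({b * div(h, ws) * div(w, ws), ws * ws, c})
end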
@spec window_reverse( Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer(), pos_integer() ) :: Nx.Tensor.t()
Reverse window partition back to spatial layout.
Input: [B*nW, ws*ws, C] -> Output: [B, H, W, C]
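The inverse applies the same transpose back; a sketch matching the partition sketch above (the permutation [0, 1, 3, 2, 4, 5] is its own inverse):

def window_reverse_sketch(win, b, h, w, ws) do
  win
  |> Nx.reshape({b, div(h, ws), div(w, ws), ws, ws, :auto})
  |> Nx.transpose(axes: [0, 1, 3, 2, 4, 5])
  |> Nx.reshape({b, h, w, :auto})
end

# Round trip: window_reverse_sketch(window_partition_sketch(x, h, w, ws), b, h, w, ws) == x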