Edifice.Vision.SwinTransformer (Edifice v0.2.0)


Swin Transformer (Shifted Window Transformer) implementation.

A hierarchical vision transformer that computes attention within local windows and shifts windows between layers for cross-window connections. Produces multi-scale feature maps like a CNN, making it suitable for dense prediction tasks.

Architecture

Image [batch, channels, height, width]
      |
+-----v---------------------+
| Patch Embedding           |  patch_size x patch_size, linear project
+---------------------------+
      |
      v
[batch, H/4 * W/4, embed_dim]
      |
+-----v---------------------+
| Stage 1                   |  depths[0] Swin blocks at embed_dim
|   Window Attention        |  Alternating regular/shifted windows
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |  2x2 spatial merge, 2x channel expand
+---------------------------+
      |
+-----v---------------------+
| Stage 2                   |  depths[1] blocks at embed_dim * 2
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |
+---------------------------+
      |
+-----v---------------------+
| Stage 3                   |  depths[2] blocks at embed_dim * 4
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |
+---------------------------+
      |
+-----v---------------------+
| Stage 4                   |  depths[3] blocks at embed_dim * 8
+---------------------------+
      |
+-----v---------------------+
| Global Average Pooling    |
+---------------------------+
      |
+-----v---------------------+
| LayerNorm                 |
+---------------------------+
      |
+-----v---------------------+
| Optional Classifier       |
+---------------------------+

Window Attention

Attention is computed within non-overlapping local windows of M x M tokens, reducing complexity from O(N^2) to O(N * M^2). Shifted windows in alternating layers enable cross-window information flow via cyclic shift and masked attention.
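To make the savings concrete, a quick back-of-the-envelope check (illustrative numbers, not library code) for the standard 224px image with 4px patches and 7x7 windows:

```python
# Attention cost for a 56x56 token grid (224px image / 4px patches)
# with 7x7 local windows.
N = 56 * 56                # total tokens
M = 7                      # window side

global_cost = N * N        # full self-attention: O(N^2)
window_cost = N * M * M    # windowed attention: O(N * M^2)

print(global_cost // window_cost)  # 64x fewer attention pairs
```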

Features:

  • Real window partitioning with M x M local attention
  • Multi-head scaled dot-product attention within each window
  • Cyclic shift for shifted window attention with boundary masking
  • Learnable relative position bias per attention head

Usage

# Swin-Tiny
alias Edifice.Vision.SwinTransformer

model = SwinTransformer.build(
  image_size: 224,
  patch_size: 4,
  embed_dim: 96,
  depths: [2, 2, 6, 2],
  num_heads: [3, 6, 12, 24],
  num_classes: 1000
)

References

  • "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows" (Liu et al., ICCV 2021)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a Swin Transformer model.

compute_relative_position_bias(window_size, num_heads)
Compute relative position bias for window attention.

compute_shift_mask(h, w, window_size, shift_size)
Compute attention mask for shifted windows.

cyclic_shift(input, shift_size, h, w)
Cyclic shift: roll tensor by -shift_size along both H and W axes.

output_size(opts \\ [])
Get the output size of a Swin Transformer model.

reverse_cyclic_shift(input, shift_size, h, w)
Reverse cyclic shift: roll tensor by +shift_size along both H and W axes.

window_partition(input, ws, h, w)
Partition a spatial tensor into non-overlapping windows.

window_reverse(input, ws, h, w, batch)
Reverse window partition back to spatial layout.

Types

build_opt()

@type build_opt() ::
  {:depths, [pos_integer()]}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, [pos_integer()]}
  | {:patch_size, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Swin Transformer model.

Options

  • :image_size - Input image size, square (default: 224)
  • :patch_size - Initial patch embedding size (default: 4)
  • :in_channels - Number of input channels (default: 3)
  • :embed_dim - Base embedding dimension (default: 96)
  • :depths - Number of blocks per stage (default: [2, 2, 6, 2])
  • :num_heads - Number of attention heads per stage (default: [3, 6, 12, 24])
  • :window_size - Window size for local attention (default: 7)
  • :mlp_ratio - MLP hidden dim ratio (default: 4.0)
  • :dropout - Dropout rate (default: 0.0)
  • :num_classes - Number of classes for classification head (optional)

Spatial dimensions at each stage must be divisible by the effective window size.

Returns

An Axon model. Without :num_classes, outputs [batch, embed_dim * 2^(num_stages-1)]. With :num_classes, outputs [batch, num_classes].
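The output-size rule can be sketched as a tiny helper (a hypothetical Python mirror of the documented behavior, not the library's implementation):

```python
def swin_output_size(embed_dim=96, depths=(2, 2, 6, 2), num_classes=None):
    """num_classes if set, otherwise embed_dim * 2^(num_stages - 1)."""
    if num_classes is not None:
        return num_classes
    return embed_dim * 2 ** (len(depths) - 1)

print(swin_output_size())                  # Swin-Tiny backbone features: 768
print(swin_output_size(num_classes=1000))  # with classifier head: 1000
```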

compute_relative_position_bias(window_size, num_heads)

@spec compute_relative_position_bias(pos_integer(), pos_integer()) :: Nx.Tensor.t()

Compute relative position bias for window attention.

Uses distance-based decay with per-head geometric slopes, similar to ALiBi but for 2D windows. Each head gets a different slope, providing diverse position sensitivity across heads.

Returns a [1, num_heads, ws*ws, ws*ws] bias tensor.
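A NumPy sketch of one way such an ALiBi-style 2D bias can be formed; the exact slope schedule and distance metric here are assumptions for illustration, not the library's:

```python
import numpy as np

def relative_position_bias(window_size, num_heads):
    # Token coordinates within one window: [ws*ws, 2].
    coords = np.stack(np.meshgrid(np.arange(window_size),
                                  np.arange(window_size),
                                  indexing="ij"), axis=-1).reshape(-1, 2)
    # Pairwise Euclidean distance between window positions: [ws*ws, ws*ws].
    diff = coords[:, None, :] - coords[None, :, :]
    dist = np.sqrt((diff ** 2).sum(-1))
    # Geometric per-head slopes, ALiBi-style: 2^(-8h / num_heads).
    slopes = 2.0 ** (-8.0 * (np.arange(num_heads) + 1) / num_heads)
    # Bias decays with distance; steeper heads attend more locally.
    bias = -slopes[:, None, None] * dist[None]
    return bias[None]                      # [1, num_heads, ws*ws, ws*ws]

b = relative_position_bias(7, 3)
print(b.shape)  # (1, 3, 49, 49)
```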

compute_shift_mask(h, w, window_size, shift_size)

@spec compute_shift_mask(pos_integer(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()

Compute attention mask for shifted windows.

Assigns region IDs based on position relative to shift boundaries, then creates a pairwise mask that blocks attention between tokens from different regions within each window.

Returns a [num_windows, ws*ws, ws*ws] mask tensor with 0.0 for allowed attention and -100.0 for blocked attention.
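The region-ID scheme can be sketched in NumPy (NumPy stands in for Nx here; this mirrors the documented behavior rather than the actual implementation):

```python
import numpy as np

def shift_mask(h, w, window_size, shift_size):
    # Assign a region ID to each position based on the shift boundaries.
    img = np.zeros((h, w), dtype=np.int64)
    slices = (slice(0, -window_size),
              slice(-window_size, -shift_size),
              slice(-shift_size, None))
    region = 0
    for hs in slices:
        for ws in slices:
            img[hs, ws] = region
            region += 1
    # Partition the region IDs into windows: [num_windows, ws*ws].
    ids = (img.reshape(h // window_size, window_size,
                       w // window_size, window_size)
              .transpose(0, 2, 1, 3)
              .reshape(-1, window_size * window_size))
    # Block attention between tokens whose regions differ within a window.
    blocked = ids[:, None, :] != ids[:, :, None]
    return np.where(blocked, -100.0, 0.0)

mask = shift_mask(8, 8, 4, 2)
print(mask.shape)  # (4, 16, 16)
```

Windows that straddle no shift boundary (like the top-left one) end up all zeros, so their attention is unrestricted; only boundary windows get -100.0 entries.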

cyclic_shift(input, shift_size, h, w)

@spec cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()

Cyclic shift: roll tensor by -shift_size along both H and W axes.
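In NumPy terms (the library operates on Nx tensors; function names here are illustrative), the shift and its inverse are plain rolls on a [B, H, W, C] tensor:

```python
import numpy as np

def cyclic_shift(x, shift_size):
    # Roll by -shift_size along H and W (axes 1 and 2 of [B, H, W, C]).
    return np.roll(x, shift=(-shift_size, -shift_size), axis=(1, 2))

def reverse_cyclic_shift(x, shift_size):
    # Roll by +shift_size to undo the shift exactly.
    return np.roll(x, shift=(shift_size, shift_size), axis=(1, 2))

x = np.arange(16).reshape(1, 4, 4, 1)
y = cyclic_shift(x, 1)
print(y[0, 0, 0, 0])                            # 5: original (1, 1) is now at (0, 0)
print((reverse_cyclic_shift(y, 1) == x).all())  # True: exact inverse
```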

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a Swin Transformer model.

Returns :num_classes if set, otherwise embed_dim * 2^(num_stages - 1).

reverse_cyclic_shift(input, shift_size, h, w)

@spec reverse_cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()

Reverse cyclic shift: roll tensor by +shift_size along both H and W axes.

window_partition(input, ws, h, w)

@spec window_partition(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()

Partition a spatial tensor into non-overlapping windows.

Input: [B, H, W, C] -> Output: [B * num_windows, ws*ws, C]
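The partition is a reshape/transpose pair; a NumPy sketch (illustrative, assuming H and W divide evenly by ws):

```python
import numpy as np

def window_partition(x, ws):
    b, h, w, c = x.shape
    x = x.reshape(b, h // ws, ws, w // ws, ws, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)   # [B, nWh, nWw, ws, ws, C]
    return x.reshape(-1, ws * ws, c)    # [B * num_windows, ws*ws, C]

x = np.arange(2 * 8 * 8 * 3).reshape(2, 8, 8, 3)
windows = window_partition(x, 4)
print(windows.shape)  # (8, 16, 3): 2 batches x 4 windows, 16 tokens each
```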

window_reverse(input, ws, h, w, batch)

@spec window_reverse(
  Nx.Tensor.t(),
  pos_integer(),
  pos_integer(),
  pos_integer(),
  pos_integer()
) ::
  Nx.Tensor.t()

Reverse window partition back to spatial layout.

Input: [B * num_windows, ws*ws, C] -> Output: [B, H, W, C]
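The reverse operation, again sketched in NumPy (illustrative; assumes the [B * num_windows, ws*ws, C] window layout described above):

```python
import numpy as np

def window_reverse(windows, ws, h, w, batch):
    x = windows.reshape(batch, h // ws, w // ws, ws, ws, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5)   # [B, nWh, ws, nWw, ws, C]
    return x.reshape(batch, h, w, -1)   # [B, H, W, C]

windows = np.arange(8 * 16 * 3).reshape(8, 16, 3)
x = window_reverse(windows, 4, 8, 8, 2)
print(x.shape)  # (2, 8, 8, 3)
```

Token 0 of window 0 lands back at spatial position (0, 0) of batch 0, so this exactly inverts the partition.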