# `Edifice.Vision.SwinTransformer`

Swin Transformer (Shifted Window Transformer) implementation.

A hierarchical vision transformer that computes attention within local
windows and shifts windows between layers for cross-window connections.
Produces multi-scale feature maps like a CNN, making it suitable for
dense prediction tasks.

## Architecture

```
Image [batch, channels, height, width]
      |
+-----v---------------------+
| Patch Embedding           |  patch_size x patch_size patches, linear projection
+---------------------------+
      |
      v
[batch, (H/patch_size) * (W/patch_size), embed_dim]
      |
+-----v---------------------+
| Stage 1                   |  depths[0] Swin blocks at embed_dim
|   Window Attention        |  Alternating regular/shifted windows
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |  2x2 neighbor merge: H, W halved, channels doubled
+---------------------------+
      |
+-----v---------------------+
| Stage 2                   |  depths[1] blocks at embed_dim * 2
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |
+---------------------------+
      |
+-----v---------------------+
| Stage 3                   |  depths[2] blocks at embed_dim * 4
+---------------------------+
      |
+-----v---------------------+
| Patch Merging             |
+---------------------------+
      |
+-----v---------------------+
| Stage 4                   |  depths[3] blocks at embed_dim * 8
+---------------------------+
      |
+-----v---------------------+
| Global Average Pooling    |
+---------------------------+
      |
+-----v---------------------+
| LayerNorm                 |
+---------------------------+
      |
+-----v---------------------+
| Optional Classifier       |
+---------------------------+
```

## Window Attention

Attention is computed within non-overlapping local windows of M x M tokens,
reducing complexity from O(N^2) to O(N * M^2). Shifted windows in alternating
layers enable cross-window information flow via cyclic shift and masked attention.
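To make the complexity claim concrete, the sketch below counts query-key score entries for one head at stage 1 of Swin-Tiny (224 x 224 input, patch size 4, window size 7), comparing global attention with M x M windowed attention. The module name `AttentionCost` is illustrative only, and projection and MLP costs are ignored.

```elixir
defmodule AttentionCost do
  # Score entries for full (global) self-attention over n tokens.
  def global_pairs(n), do: n * n

  # Windowed attention: each of the n tokens scores only the m * m tokens
  # in its own window.
  def window_pairs(n, m), do: n * m * m
end

# Stage 1 of Swin-Tiny: 224 / 4 = 56 tokens per side.
n = div(224, 4) * div(224, 4)
m = 7

global = AttentionCost.global_pairs(n)
windowed = AttentionCost.window_pairs(n, m)
IO.puts("global: #{global}, windowed: #{windowed}, ratio: #{div(global, windowed)}")
# prints "global: 9834496, windowed: 153664, ratio: 64"
```

The ratio is N / M², so the saving grows linearly with the number of tokens, which is what keeps windowed attention affordable at high resolution.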

Features:
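- Explicit window partitioning, with attention restricted to each M x M window
- Multi-head scaled dot-product attention within each window
- Cyclic shift for shifted-window attention, with boundary masking
- Per-head relative position bias: distance-based decay with a different geometric slope per head (ALiBi-style; see `compute_relative_position_bias/2`)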

## Usage

    alias Edifice.Vision.SwinTransformer

    # Swin-Tiny
    model = SwinTransformer.build(
      image_size: 224,
      patch_size: 4,
      embed_dim: 96,
      depths: [2, 2, 6, 2],
      num_heads: [3, 6, 12, 24],
      num_classes: 1000
    )

## References

- "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows"
  (Liu et al., ICCV 2021)

# `build_opt`

```elixir
@type build_opt() ::
  {:depths, [pos_integer()]}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, [pos_integer()]}
  | {:patch_size, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Swin Transformer model.

## Options

  - `:image_size` - Input image size, square (default: 224)
  - `:patch_size` - Initial patch embedding size (default: 4)
  - `:in_channels` - Number of input channels (default: 3)
  - `:embed_dim` - Base embedding dimension (default: 96)
  - `:depths` - Number of blocks per stage (default: [2, 2, 6, 2])
  - `:num_heads` - Number of attention heads per stage (default: [3, 6, 12, 24])
  - `:window_size` - Window size for local attention (default: 7)
  - `:mlp_ratio` - MLP hidden dim ratio (default: 4.0)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:num_classes` - Number of output classes for the classification head; when omitted (`nil`), no head is added and the pooled features are returned (default: `nil`)

At every stage, the feature-map side length (`image_size / patch_size` at stage 1, halved by each patch merging) must be divisible by the effective window size.
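The divisibility rule can be checked ahead of time. This list-based sketch uses a hypothetical module `SwinStages` (not part of the library) and assumes that stage 1 sees `image_size / patch_size` tokens per side, that each patch merging halves the side, and that the configured `window_size` is used unchanged at every stage. If the implementation clamps the window on small feature maps, its notion of effective window size may be more permissive.

```elixir
defmodule SwinStages do
  # Side length of the token grid at each stage: image_size / patch_size at
  # stage 1, halved by each subsequent patch merging.
  def resolutions(image_size, patch_size, num_stages) do
    first = div(image_size, patch_size)
    for stage <- 0..(num_stages - 1), do: div(first, Integer.pow(2, stage))
  end

  # True when patch embedding, every 2x2 merge, and every window partition
  # divide evenly. Assumes window_size is not clamped at any stage.
  def valid?(image_size, patch_size, num_stages, window_size) do
    sizes = resolutions(image_size, patch_size, num_stages)
    # Every stage except the last feeds a 2x2 patch merging, so its side must be even.
    merge_ok = sizes |> Enum.drop(-1) |> Enum.all?(&(rem(&1, 2) == 0))

    rem(image_size, patch_size) == 0 and merge_ok and
      Enum.all?(sizes, &(rem(&1, window_size) == 0))
  end
end

IO.inspect(SwinStages.resolutions(224, 4, 4))
# prints [56, 28, 14, 7]
```

With the Swin-Tiny defaults every stage is a multiple of 7. Under the no-clamping assumption, raising `window_size` to 8 would fail at stage 2, since 28 is not a multiple of 8.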

## Returns

  An Axon model. Without `:num_classes`, outputs `[batch, embed_dim * 2^(num_stages-1)]`,
  where `num_stages = length(depths)`. With `:num_classes`, outputs `[batch, num_classes]`.

# `compute_relative_position_bias`

```elixir
@spec compute_relative_position_bias(pos_integer(), pos_integer()) :: Nx.Tensor.t()
```

Compute relative position bias for window attention.

Uses distance-based decay with per-head geometric slopes, similar to ALiBi
but for 2D windows. Each head gets a different slope, providing diverse
position sensitivity across heads.

Returns a [1, num_heads, ws*ws, ws*ws] bias tensor, where ws is the window size.
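The following is a hedged, list-based sketch of an ALiBi-style 2D bias. The slope schedule 2^(-8(h+1)/num_heads) is borrowed from ALiBi and the Manhattan distance between window positions is an assumption; the library's actual distance metric, slopes, and argument order for `compute_relative_position_bias/2` may differ. `AlibiBias2D` is a hypothetical name. For contrast, the original Swin paper learns a bias table indexed by relative offset rather than using a fixed decay.

```elixir
defmodule AlibiBias2D do
  # Assumed ALiBi-style geometric slopes, one per head.
  def slopes(num_heads) do
    for h <- 0..(num_heads - 1), do: :math.pow(2, -8 * (h + 1) / num_heads)
  end

  # Returns nested lists shaped [num_heads][ws*ws][ws*ws]; the documented
  # tensor adds a leading axis of size 1. Distance metric is an assumption.
  def bias(window_size, num_heads) do
    positions = for r <- 0..(window_size - 1), c <- 0..(window_size - 1), do: {r, c}

    for slope <- slopes(num_heads) do
      for {r1, c1} <- positions do
        for {r2, c2} <- positions do
          # Farther tokens get a more negative bias; same token gets 0.0.
          slope * -(abs(r1 - r2) + abs(c1 - c2))
        end
      end
    end
  end
end
```

Because the bias depends only on distance, each head's matrix is symmetric with zeros on the diagonal, and heads with smaller slopes penalize distance less, so they attend more evenly across the window.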

# `compute_shift_mask`

```elixir
@spec compute_shift_mask(pos_integer(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()
```

Compute attention mask for shifted windows.

Assigns region IDs based on position relative to shift boundaries,
then creates a pairwise mask that blocks attention between tokens
from different regions within each window.

Returns a [num_windows, ws*ws, ws*ws] mask tensor with 0.0 for
allowed attention and -100.0 for blocked attention.
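A list-based sketch of the region-ID construction used in the reference Swin code: each axis is split into three bands (indices before the last window, the last window minus its final `shift_size` positions, and the final `shift_size` positions), giving up to nine regions. The module name `ShiftMask` and the argument order `(h, w, ws, shift)` are assumptions for illustration, not the library's signature.

```elixir
defmodule ShiftMask do
  # Band index of position i along an axis of length n.
  defp band(i, n, ws, shift) do
    cond do
      i < n - ws -> 0
      i < n - shift -> 1
      true -> 2
    end
  end

  # Returns nested lists shaped [num_windows][ws*ws][ws*ws]:
  # 0.0 where two tokens share a region, -100.0 where they do not.
  def mask(h, w, ws, shift) do
    region = fn r, c -> band(r, h, ws, shift) * 3 + band(c, w, ws, shift) end

    for wr <- 0..(div(h, ws) - 1), wc <- 0..(div(w, ws) - 1) do
      ids = for r <- 0..(ws - 1), c <- 0..(ws - 1), do: region.(wr * ws + r, wc * ws + c)

      for a <- ids do
        for b <- ids, do: if(a == b, do: 0.0, else: -100.0)
      end
    end
  end
end
```

In a 4 x 4 map with 2 x 2 windows and a shift of 1, the top-left window lies in a single region and is fully unmasked, while the bottom-right window spans four regions, so every off-diagonal pair there is blocked.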

# `cyclic_shift`

```elixir
@spec cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()
```

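Cyclic shift: roll tensor by -shift_size along both H and W axes, so the token at (shift_size, shift_size) moves to (0, 0) and the first shift_size rows and columns wrap around to the bottom and right edges.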
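On a single H x W grid (batch and channel axes dropped for brevity), a roll by -shift_size is a left rotation of the rows followed by a left rotation within each row. `Roll` is a hypothetical list-based sketch; a tensor version could be assembled from slices along an axis plus concatenation (for example `Nx.slice_along_axis/4` and `Nx.concatenate/2`).

```elixir
defmodule Roll do
  # Rotate a list left by s: the element at index i moves to index i - s (mod n),
  # i.e. a roll by -s.
  def left(list, s) do
    k = rem(s, length(list))
    Enum.drop(list, k) ++ Enum.take(list, k)
  end

  # Roll an H x W grid (list of rows) by -s along H, then along W.
  def cyclic_shift(grid, s), do: grid |> left(s) |> Enum.map(&left(&1, s))
end

Roll.cyclic_shift([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 1)
# => [[5, 6, 4], [8, 9, 7], [2, 3, 1]]
```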

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size of a Swin Transformer model.

Returns `:num_classes` if set, otherwise `embed_dim * 2^(num_stages - 1)`, where `num_stages = length(depths)`.
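The documented rule, written out as a hypothetical helper (`SwinOutput` is not the library's module) that falls back to the `build/1` defaults of `embed_dim: 96` and `depths: [2, 2, 6, 2]`:

```elixir
defmodule SwinOutput do
  # num_classes when given, else embed_dim * 2^(length(depths) - 1).
  def output_size(opts) do
    case Keyword.get(opts, :num_classes) do
      nil ->
        embed_dim = Keyword.get(opts, :embed_dim, 96)
        depths = Keyword.get(opts, :depths, [2, 2, 6, 2])
        embed_dim * Integer.pow(2, length(depths) - 1)

      num_classes ->
        num_classes
    end
  end
end

SwinOutput.output_size([])                # => 768
SwinOutput.output_size(embed_dim: 128)    # => 1024
SwinOutput.output_size(num_classes: 1000) # => 1000
```

For Swin-Tiny this is 96 * 2^3 = 768; with `embed_dim: 128` (the Swin-B width) it is 1024.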

# `reverse_cyclic_shift`

```elixir
@spec reverse_cyclic_shift(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()
```

Reverse cyclic shift: roll tensor by +shift_size along both H and W axes, undoing `cyclic_shift/4`.
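In a list-based sketch on one H x W grid, the reverse roll is a right rotation on both axes, and applying it after the forward roll restores the original grid. `ShiftPair` is a hypothetical module that keeps both directions together so the round trip can be checked.

```elixir
defmodule ShiftPair do
  # Roll by -s: element at index i moves to i - s (mod n).
  def left(list, s) do
    k = rem(s, length(list))
    Enum.drop(list, k) ++ Enum.take(list, k)
  end

  # Roll by +s: element at index i moves to i + s (mod n).
  def right(list, s) do
    k = rem(s, length(list))
    Enum.take(list, -k) ++ Enum.drop(list, -k)
  end

  def cyclic_shift(grid, s), do: grid |> left(s) |> Enum.map(&left(&1, s))
  def reverse_cyclic_shift(grid, s), do: grid |> right(s) |> Enum.map(&right(&1, s))
end

grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
ShiftPair.reverse_cyclic_shift(ShiftPair.cyclic_shift(grid, 1), 1) == grid
# => true
```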

# `window_partition`

```elixir
@spec window_partition(Nx.Tensor.t(), pos_integer(), pos_integer(), pos_integer()) ::
  Nx.Tensor.t()
```

Partition a spatial tensor into non-overlapping windows.

Input: [B, H, W, C] -> Output: [B*nW, ws*ws, C], where ws is the window size and nW = (H/ws) * (W/ws) is the number of windows per image.
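A hypothetical list-based sketch of the same layout change. Tensor implementations, including the reference Swin code, typically reshape to [B, H/ws, ws, W/ws, ws, C], transpose to axis order [0, 1, 3, 2, 4, 5], and reshape to [B*nW, ws*ws, C]. That yields windows in batch-major, row-major order with tokens row-major inside each window, which is the order this sketch reproduces; whether the library uses exactly this order is an assumption.

```elixir
defmodule Windows do
  # x: nested lists shaped [B][H][W][C]; ws: window size.
  # Returns [B*nW][ws*ws][C].
  def partition(x, ws) do
    for image <- x,
        wr <- 0..(div(length(image), ws) - 1),
        wc <- 0..(div(length(hd(image)), ws) - 1) do
      # Collect the ws x ws tokens of window (wr, wc) in row-major order.
      for r <- 0..(ws - 1), c <- 0..(ws - 1) do
        image |> Enum.at(wr * ws + r) |> Enum.at(wc * ws + c)
      end
    end
  end
end

# One 4 x 4 image whose two channels hold each token's [row, col].
x = [for(r <- 0..3, do: for(c <- 0..3, do: [r, c]))]
hd(Windows.partition(x, 2))
# => [[0, 0], [0, 1], [1, 0], [1, 1]]
```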

# `window_reverse`

```elixir
@spec window_reverse(
  Nx.Tensor.t(),
  pos_integer(),
  pos_integer(),
  pos_integer(),
  pos_integer()
) ::
  Nx.Tensor.t()
```

Reverse window partition back to spatial layout.

Input: [B*nW, ws*ws, C] -> Output: [B, H, W, C]. This is the inverse of `window_partition/4`.
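A hypothetical list-based inverse, assuming windows are ordered batch-major and row-major over the window grid, with tokens row-major inside each window (the order produced by the usual reshape/transpose partition). The `(windows, ws, h, w)` argument order is illustrative and is not the library's `window_reverse/5` signature.

```elixir
defmodule WindowsBack do
  # windows: [B*nW][ws*ws][C]; h, w: spatial size of each image.
  # Returns [B][H][W][C].
  def reverse(windows, ws, h, w) do
    per_row = div(w, ws)
    nw = div(h, ws) * per_row

    windows
    |> Enum.chunk_every(nw)
    |> Enum.map(fn image_windows ->
      for r <- 0..(h - 1) do
        for c <- 0..(w - 1) do
          # Which window holds (r, c), and where inside that window it sits.
          win = div(r, ws) * per_row + div(c, ws)
          token = rem(r, ws) * ws + rem(c, ws)
          image_windows |> Enum.at(win) |> Enum.at(token)
        end
      end
    end)
  end
end

windows = [
  [[0, 0], [0, 1], [1, 0], [1, 1]],
  [[0, 2], [0, 3], [1, 2], [1, 3]],
  [[2, 0], [2, 1], [3, 0], [3, 1]],
  [[2, 2], [2, 3], [3, 2], [3, 3]]
]

[image] = WindowsBack.reverse(windows, 2, 4, 4)
Enum.at(image, 3)
# => [[3, 0], [3, 1], [3, 2], [3, 3]]
```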

