# `Edifice.Vision.MambaVision`

MambaVision: A Hybrid Mamba-Transformer Vision Backbone.

Implements the MambaVision architecture from "MambaVision: A Hybrid
Mamba-Transformer Vision Backbone" (Hatamizadeh & Kautz, NVIDIA, 2024),
a hierarchical 4-stage vision backbone that uses CNN blocks in the early,
high-resolution stages and hybrid Mamba SSM + windowed self-attention blocks
in the later stages.

## Key Innovation

Instead of applying Mamba uniformly across all stages (as Vim and VMamba do),
MambaVision chooses the block type per stage:
- **Stages 1-2**: Pure CNN blocks (fast at high resolution)
- **Stages 3-4**: Mamba mixer blocks in the first half of each stage, windowed self-attention blocks in the second half

The MambaVisionMixer modifies the standard Mamba mixer in three ways:
1. **Non-causal convolution** (no directional bias for 2D data)
2. **Dual branch**: the SSM runs on half the channels, and a symmetric Conv + SiLU branch (no SSM) runs on the other half
3. **Concatenation** of the two branch outputs instead of multiplicative gating
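The three mixer changes can be illustrated with a toy, pure-Elixir sketch on plain lists. It is not the Edifice implementation: `MixerSketch` is a hypothetical module, a `{channels, length}` tensor is modeled as a list of per-channel sequences, `toy_scan/2` is a scalar linear recurrence standing in for the selective scan, and the kernel weights are arbitrary.

```elixir
defmodule MixerSketch do
  # Hypothetical toy model of the MambaVisionMixer dataflow (not Edifice code).
  # A {channels, length} tensor is a list of per-channel lists of floats.

  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Non-causal ("same") 1D conv: zero-pad both sides, so each output
  # position mixes in neighbours before AND after it.
  def conv_same(xs, kernel) do
    k = length(kernel)
    left = div(k - 1, 2)
    slide(pad(left) ++ xs ++ pad(k - 1 - left), kernel)
  end

  # Causal 1D conv (standard Mamba): zero-pad the left side only,
  # so position i only sees positions <= i.
  def conv_causal(xs, kernel), do: slide(pad(length(kernel) - 1) ++ xs, kernel)

  # Toy stand-in for the selective scan: h_t = a * h_{t-1} + x_t.
  def toy_scan(xs, a \\ 0.5), do: Enum.scan(xs, 0.0, fn x, h -> a * h + x end)

  # Dual-branch mixer: SSM path on the first half of the channels,
  # symmetric conv + SiLU path (no SSM) on the second half, then the
  # channel lists are concatenated instead of multiplied by a gate.
  def mix(channels, kernel) do
    {ssm_in, sym_in} = Enum.split(channels, div(length(channels), 2))

    ssm_out =
      Enum.map(ssm_in, fn ch -> ch |> conv_same(kernel) |> Enum.map(&silu/1) |> toy_scan() end)

    sym_out = Enum.map(sym_in, fn ch -> ch |> conv_same(kernel) |> Enum.map(&silu/1) end)

    ssm_out ++ sym_out
  end

  defp pad(n), do: List.duplicate(0.0, n)

  # Slide the kernel over the padded sequence (stride 1, no flip).
  defp slide(padded, kernel) do
    padded
    |> Enum.chunk_every(length(kernel), 1, :discard)
    |> Enum.map(fn window ->
      window |> Enum.zip(kernel) |> Enum.map(fn {a, b} -> a * b end) |> Enum.sum()
    end)
  end
end
```

For an impulse at the last position, `MixerSketch.conv_causal([0.0, 0.0, 1.0], [1.0, 1.0, 1.0])` returns `[0.0, 0.0, 1.0]`, while `conv_same/2` returns `[0.0, 1.0, 1.0]`: the non-causal kernel lets position 1 see position 2, which suits image tokens that have no natural ordering.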

## Architecture

```
Input: (B, 3, 224, 224)
  -> PatchEmbed (2x Conv3x3 stride 2 = 4x downsample)
  -> Stage 1 (ConvBlocks)           -> Downsample (Conv stride 2)
  -> Stage 2 (ConvBlocks)           -> Downsample
  -> Stage 3 (Mamba + Attention)    -> Downsample
  -> Stage 4 (Mamba + Attention)
  -> LayerNorm -> Global Avg Pool -> Linear -> Output
```

Channel progression: dim -> 2*dim -> 4*dim -> 8*dim
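The shape bookkeeping implied by the diagram and the channel progression can be checked with a few lines of plain Elixir. `StageShapes.stage_shapes/2` is a hypothetical helper, not part of Edifice; it assumes the patch embed reduces the spatial size by 4, that each downsample halves the spatial size and doubles the channels, and that `image_size` is divisible by 32.

```elixir
defmodule StageShapes do
  # Hypothetical helper (not Edifice code): {channels, height, width} of the
  # feature map produced by each of the 4 stages.
  # Assumes: patch embed = 4x downsample; each downsample between stages
  # halves height/width and doubles channels; image_size divisible by 32.
  def stage_shapes(image_size \\ 224, dim \\ 80) do
    for stage <- 1..4 do
      factor = Integer.pow(2, stage - 1)
      side = div(image_size, 4 * factor)
      {stage, {dim * factor, side, side}}
    end
  end
end
```

With the defaults this yields `56 x 56` at stage 1 down to a `640 x 7 x 7` map at stage 4 (`8 * dim` channels), which global average pooling turns into a 640-wide feature vector.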

## Model Variants

| Variant | dim | depths     | Params |
|---------|-----|------------|--------|
| Tiny    | 80  | [1,3,8,4]  | ~32M   |
| Small   | 96  | [3,3,7,5]  | ~50M   |
| Base    | 128 | [3,3,10,5] | ~98M   |
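To see how the `depths` in the table translate into blocks, the sketch below expands a depth list into per-stage block types. `StageLayout` is hypothetical and only encodes the stage rule described under Key Innovation. How an odd hybrid-stage depth (such as 7 or 5 in Small and Base) is split is an assumption here (the extra block goes to the Mamba half), so check the Edifice source before relying on it.

```elixir
defmodule StageLayout do
  # Hypothetical helper (not Edifice code): list the block types in each stage.
  # Stages 1-2: all conv blocks. Stages 3-4: Mamba mixer blocks first,
  # then windowed attention blocks.
  # ASSUMPTION: for an odd depth, the extra block goes to the Mamba half.
  def layout(depths) do
    depths
    |> Enum.with_index(1)
    |> Enum.map(fn
      {n, stage} when stage <= 2 ->
        {stage, List.duplicate(:conv, n)}

      {n, stage} ->
        mamba = div(n + 1, 2)
        {stage, List.duplicate(:mamba, mamba) ++ List.duplicate(:attention, n - mamba)}
    end)
  end
end
```

For Tiny, `StageLayout.layout([1, 3, 8, 4])` gives stage 3 four mixer blocks followed by four attention blocks, and stage 4 two of each.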

## Usage

    alias Edifice.Vision.MambaVision

    # Tiny-sized backbone with a 10-class classification head
    model = MambaVision.build(
      image_size: 224,
      dim: 80,
      depths: [1, 3, 8, 4],
      num_heads: [2, 4, 8, 16],
      num_classes: 10
    )

## References
- Paper: https://arxiv.org/abs/2407.08083
- Code: https://github.com/NVlabs/MambaVision

# `build_opt`

```elixir
@type build_opt() ::
  {:d_conv, pos_integer()}
  | {:d_state, pos_integer()}
  | {:depths, [pos_integer()]}
  | {:dim, pos_integer()}
  | {:dropout, float()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, [pos_integer()]}
```

Options for `build/1`.

# `base_config`

```elixir
@spec base_config() :: keyword()
```

Get the Base variant configuration.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a MambaVision model.

## Options
  - `:image_size` - Input image size, square (default: 224)
  - `:in_channels` - Number of input channels (default: 3)
  - `:dim` - Base channel dimension, doubles each stage (default: 80)
  - `:depths` - Number of blocks per stage (default: [1, 3, 8, 4])
  - `:num_heads` - Attention heads per stage (default: [2, 4, 8, 16])
  - `:mlp_ratio` - MLP expansion ratio in hybrid stages (default: 4)
  - `:dropout` - Dropout/drop path rate (default: 0.0)
  - `:d_state` - SSM state dimension (default: 8)
  - `:d_conv` - SSM convolution kernel size (default: 3)
  - `:num_classes` - Number of output classes; when set, a classification head is added (default: `nil`, no head)

## Returns
  Without `:num_classes`: `[batch, 8*dim]` feature vector.
  With `:num_classes`: `[batch, num_classes]` logits.
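The documented defaults and the two output shapes can be summarized in a pure-Elixir sketch. `BuildOptsSketch` is hypothetical: it only restates the defaults and return shapes listed above and does not call Edifice, so the real `build/1` and `output_size/1` may resolve options differently.

```elixir
defmodule BuildOptsSketch do
  # Hypothetical sketch (not Edifice code): the defaults documented for build/1.
  @defaults [
    image_size: 224,
    in_channels: 3,
    dim: 80,
    depths: [1, 3, 8, 4],
    num_heads: [2, 4, 8, 16],
    mlp_ratio: 4,
    dropout: 0.0,
    d_state: 8,
    d_conv: 3,
    num_classes: nil
  ]

  # User options override the defaults key by key.
  def resolve(opts), do: Keyword.merge(@defaults, opts)

  # Output width: num_classes logits if a head is requested,
  # otherwise the 8 * dim pooled feature vector.
  def output_width(opts) do
    resolved = resolve(opts)
    resolved[:num_classes] || 8 * resolved[:dim]
  end
end
```

For example, `BuildOptsSketch.output_width(dim: 128)` gives `1024` (Base-sized features), while adding `num_classes: 10` gives `10`.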

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a MambaVision model.

# `small_config`

```elixir
@spec small_config() :: keyword()
```

Get the Small variant configuration.

# `tiny_config`

```elixir
@spec tiny_config() :: keyword()
```

Get the Tiny variant configuration.

