MambaVision: A Hybrid Mamba-Transformer Vision Backbone.
Implements the MambaVision architecture from "MambaVision: A Hybrid Mamba-Transformer Vision Backbone" (Hatamizadeh & Kautz, NVIDIA, 2024). A hierarchical 4-stage vision backbone that uses CNN blocks in early stages and hybrid Mamba SSM + windowed self-attention in later stages.
Key Innovation
Instead of applying Mamba uniformly (like Vim/VMamba), MambaVision uses a stage-appropriate mix:
- Stages 1-2: Pure CNN blocks (fast at high resolution)
- Stages 3-4: First half Mamba SSM, second half windowed attention
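The stage mix above can be sketched as a small helper that maps a `depths` list to block types. `StageLayoutSketch` is a hypothetical module written for illustration and is not part of MambaVision. How an odd depth is split is an assumption here: the extra block is given to the Mamba half.

```elixir
defmodule StageLayoutSketch do
  # Hypothetical helper, not part of the MambaVision module.
  # Returns {stage_number, block_types} for each of the four stages.
  def layout(depths) when length(depths) == 4 do
    depths
    |> Enum.with_index(1)
    |> Enum.map(fn {depth, stage} -> {stage, blocks(stage, depth)} end)
  end

  # Stages 1-2: convolutional blocks only.
  defp blocks(stage, depth) when stage in [1, 2], do: List.duplicate(:conv, depth)

  # Stages 3-4: Mamba blocks first, attention blocks last.
  # Assumption: for odd depths the extra block goes to the Mamba half.
  defp blocks(_stage, depth) do
    n_attn = div(depth, 2)
    List.duplicate(:mamba, depth - n_attn) ++ List.duplicate(:attention, n_attn)
  end
end

# Tiny variant depths: stage 4 becomes [:mamba, :mamba, :attention, :attention]
IO.inspect(StageLayoutSketch.layout([1, 3, 8, 4]))
```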
The MambaVisionMixer modifies standard Mamba with:
- Non-causal convolution (no directional bias for 2D data)
- Dual-branch: SSM on half channels, symmetric Conv+SiLU on other half
- Concatenation instead of multiplicative gating
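To make the first bullet concrete, the sketch below contrasts causal and non-causal padding for a 1D convolution over a plain list. It is a toy illustration only: `ConvPaddingSketch` is a hypothetical module, it uses integer lists rather than tensors, and it does not reproduce the mixer's SSM branch or the concatenation step.

```elixir
defmodule ConvPaddingSketch do
  # 1D convolution (cross-correlation) over a list with explicit zero padding.
  def conv1d(signal, kernel, pad_left, pad_right) do
    padded = List.duplicate(0, pad_left) ++ signal ++ List.duplicate(0, pad_right)

    padded
    |> Enum.chunk_every(length(kernel), 1, :discard)
    |> Enum.map(fn window ->
      window |> Enum.zip(kernel) |> Enum.map(fn {x, w} -> x * w end) |> Enum.sum()
    end)
  end

  # Causal: pad only on the left, so output t depends on inputs up to t.
  def causal(signal, kernel), do: conv1d(signal, kernel, length(kernel) - 1, 0)

  # Non-causal: pad both sides equally (odd kernel sizes assumed),
  # so output t depends on neighbors on both sides.
  def non_causal(signal, kernel) do
    pad = div(length(kernel) - 1, 2)
    conv1d(signal, kernel, pad, pad)
  end
end

impulse = [0, 0, 1, 0, 0]
kernel = [1, 1, 1]

# The impulse only spreads forward in the sequence.
IO.inspect(ConvPaddingSketch.causal(impulse, kernel), charlists: :as_lists)
# => [0, 0, 1, 1, 1]

# The impulse spreads symmetrically, with no directional bias.
IO.inspect(ConvPaddingSketch.non_causal(impulse, kernel), charlists: :as_lists)
# => [0, 1, 1, 1, 0]
```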
Architecture
Input: (B, 3, 224, 224)
-> PatchEmbed (2x Conv3x3 stride 2 = 4x downsample)
-> Stage 1 (ConvBlocks) -> Downsample (Conv stride 2)
-> Stage 2 (ConvBlocks) -> Downsample
-> Stage 3 (Mamba + Attention) -> Downsample
-> Stage 4 (Mamba + Attention)
-> LayerNorm -> Global Avg Pool -> Linear -> Output
Channel progression: dim -> 2*dim -> 4*dim -> 8*dim
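As a worked example of the pipeline above, the following sketch computes the spatial size and channel count inside each stage. `StageShapesSketch` is a hypothetical helper. It assumes the image size is divisible by 32, so each stride-2 step halves the resolution exactly.

```elixir
defmodule StageShapesSketch do
  # Hypothetical helper: {stage, spatial_size, channels} inside each stage.
  def shapes(image_size, dim) do
    # PatchEmbed downsamples by 4x before stage 1.
    base = div(image_size, 4)

    for stage <- 0..3 do
      # Each Downsample between stages halves the size and doubles the channels.
      scale = Integer.pow(2, stage)
      {stage + 1, div(base, scale), dim * scale}
    end
  end
end

IO.inspect(StageShapesSketch.shapes(224, 80))
# => [{1, 56, 80}, {2, 28, 160}, {3, 14, 320}, {4, 7, 640}]
```

For the Tiny variant, the final stage therefore operates on a 7x7 grid with 640 channels, which matches the 8*dim feature width.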
Model Variants
| Variant | dim | depths | Params |
|---|---|---|---|
| Tiny | 80 | [1,3,8,4] | ~32M |
| Small | 96 | [3,3,7,5] | ~50M |
| Base | 128 | [3,3,10,5] | ~98M |
Usage
model = MambaVision.build(
image_size: 224,
dim: 80,
depths: [1, 3, 8, 4],
num_heads: [2, 4, 8, 16],
num_classes: 10
)
Summary
Functions
base_config/0 - Get the Base variant configuration.
build/1 - Build a MambaVision model.
output_size/1 - Get the output size of a MambaVision model.
small_config/0 - Get the Small variant configuration.
tiny_config/0 - Get the Tiny variant configuration.
Types
@type build_opt() :: {:d_conv, pos_integer()} | {:d_state, pos_integer()} | {:depths, [pos_integer()]} | {:dim, pos_integer()} | {:dropout, float()} | {:image_size, pos_integer()} | {:in_channels, pos_integer()} | {:mlp_ratio, pos_integer()} | {:num_classes, pos_integer() | nil} | {:num_heads, [pos_integer()]}
Options for build/1.
Functions
@spec base_config() :: keyword()
Get the Base variant configuration.
Build a MambaVision model.
Options
- :image_size - Input image size, square (default: 224)
- :in_channels - Number of input channels (default: 3)
- :dim - Base channel dimension, doubles each stage (default: 80)
- :depths - Number of blocks per stage (default: [1, 3, 8, 4])
- :num_heads - Attention heads per stage (default: [2, 4, 8, 16])
- :mlp_ratio - MLP expansion ratio in hybrid stages (default: 4)
- :dropout - Dropout/drop path rate (default: 0.0)
- :d_state - SSM state dimension (default: 8)
- :d_conv - SSM convolution kernel size (default: 3)
- :num_classes - Classification head size (optional)
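The sketch below shows one way the documented defaults could be merged with user options, and how `:dim` and `:num_heads` relate. `BuildOptsSketch` is a hypothetical module, and the divisibility check is an assumption for illustration; the documentation does not say whether build/1 performs it. Head dimensions are computed for all four entries, although only stages 3-4 contain attention blocks.

```elixir
defmodule BuildOptsSketch do
  # Defaults copied from the documented option list.
  @defaults [
    image_size: 224,
    in_channels: 3,
    dim: 80,
    depths: [1, 3, 8, 4],
    num_heads: [2, 4, 8, 16],
    mlp_ratio: 4,
    dropout: 0.0,
    d_state: 8,
    d_conv: 3,
    num_classes: nil
  ]

  # Hypothetical: user-supplied options override the defaults.
  def resolve(opts), do: Keyword.merge(@defaults, opts)

  # Channels per attention head for each stage, given that the
  # channel count doubles per stage: dim, 2*dim, 4*dim, 8*dim.
  def head_dims(opts) do
    opts = resolve(opts)

    opts[:num_heads]
    |> Enum.with_index()
    |> Enum.map(fn {heads, stage} ->
      channels = opts[:dim] * Integer.pow(2, stage)

      # Assumed check: channels must split evenly across heads.
      if rem(channels, heads) != 0 do
        raise ArgumentError, "#{channels} channels not divisible by #{heads} heads"
      end

      div(channels, heads)
    end)
  end
end

IO.inspect(BuildOptsSketch.head_dims([]), charlists: :as_lists)
# => [40, 40, 40, 40]
```

With the default head counts, doubling both channels and heads per stage keeps the per-head width constant.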
Returns
Without :num_classes: [batch, 8*dim] feature vector.
With :num_classes: [batch, num_classes] logits.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a MambaVision model.
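A minimal sketch of the size rule described under Returns for build/1. `OutputSizeSketch` is a hypothetical re-implementation, and it assumes output_size/1 accepts the same options and follows the same rule; the real function may behave differently, for example by ignoring :num_classes.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical re-implementation, assuming the Returns rule of build/1:
  # logits width when :num_classes is given, otherwise the 8*dim feature width.
  def output_size(opts \\ []) do
    case Keyword.get(opts, :num_classes) do
      nil -> 8 * Keyword.get(opts, :dim, 80)
      num_classes -> num_classes
    end
  end
end

IO.inspect(OutputSizeSketch.output_size(dim: 128))
# => 1024
IO.inspect(OutputSizeSketch.output_size(dim: 80, num_classes: 10))
# => 10
```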
@spec small_config() :: keyword()
Get the Small variant configuration.
@spec tiny_config() :: keyword()
Get the Tiny variant configuration.