Edifice.Vision.MambaVision (Edifice v0.2.0)


MambaVision: A Hybrid Mamba-Transformer Vision Backbone.

Implements the MambaVision architecture from "MambaVision: A Hybrid Mamba-Transformer Vision Backbone" (Hatamizadeh & Kautz, NVIDIA, 2024). The model is a hierarchical 4-stage vision backbone that uses CNN blocks in the early stages and a hybrid of Mamba SSM and windowed self-attention blocks in the later stages.

Key Innovation

Instead of applying Mamba uniformly (like Vim/VMamba), MambaVision uses a stage-appropriate mix:

  • Stages 1-2: Pure CNN blocks (fast at high resolution)
  • Stages 3-4: First half Mamba SSM, second half windowed attention
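The stage-appropriate mix above can be sketched as a mapping from `:depths` to block types. This is a minimal illustration, not Edifice's implementation: `BlockLayoutSketch` is a hypothetical module, and giving the extra block to Mamba when a depth is odd is an assumption.

```elixir
defmodule BlockLayoutSketch do
  # Hypothetical helper for illustration; not part of Edifice.

  # Maps a 4-element depths list to a list of block types per stage.
  def layout(depths) do
    depths
    |> Enum.with_index(1)
    |> Enum.map(fn {depth, stage} -> stage_blocks(stage, depth) end)
  end

  # Stages 1-2: pure CNN blocks.
  defp stage_blocks(stage, depth) when stage <= 2 do
    List.duplicate(:conv, depth)
  end

  # Stages 3-4: Mamba blocks first, windowed attention blocks last.
  # Assumption: for an odd depth, the extra block is a Mamba block.
  defp stage_blocks(_stage, depth) do
    n_mamba = div(depth + 1, 2)
    List.duplicate(:mamba, n_mamba) ++ List.duplicate(:attention, depth - n_mamba)
  end
end

# Tiny variant depths: stage 3 gets 4 Mamba + 4 attention blocks.
IO.inspect(BlockLayoutSketch.layout([1, 3, 8, 4]))
```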

The MambaVisionMixer modifies standard Mamba with:

  1. Non-causal convolution (no directional bias for 2D data)
  2. Dual-branch: SSM on half channels, symmetric Conv+SiLU on other half
  3. Concatenation instead of multiplicative gating
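The channel bookkeeping of the dual-branch mixer can be sketched with plain lists. This is a toy under stated assumptions: `MixerSketch` is hypothetical, the convolution is omitted, and `ssm_fun` is a placeholder for the selective scan. It only shows that the two halves are concatenated rather than multiplied together.

```elixir
defmodule MixerSketch do
  # Hypothetical toy; not part of Edifice. Operates on one token's
  # channel values represented as a plain list of floats.

  # SiLU activation: x * sigmoid(x).
  def silu(x), do: x / (1.0 + :math.exp(-x))

  def mix(channels, ssm_fun) do
    half = div(length(channels), 2)
    {x, z} = Enum.split(channels, half)

    # SSM branch: activation followed by the (placeholder) SSM.
    ssm_branch = ssm_fun.(Enum.map(x, &silu/1))

    # Symmetric branch: activation only (the convolution is omitted here).
    sym_branch = Enum.map(z, &silu/1)

    # Concatenate the two halves instead of gating one with the other.
    ssm_branch ++ sym_branch
  end
end

# An identity function stands in for the SSM; 4 channels in, 4 channels out.
IO.inspect(MixerSketch.mix([0.5, -1.0, 2.0, 0.0], fn xs -> xs end))
```

With multiplicative gating the two halves would be combined element-wise into a single half-width result; concatenation keeps both branches' features.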

Architecture

Input: (B, 3, 224, 224)
  -> PatchEmbed (2x Conv3x3 stride 2 = 4x downsample)
  -> Stage 1 (ConvBlocks)           -> Downsample (Conv stride 2)
  -> Stage 2 (ConvBlocks)           -> Downsample
  -> Stage 3 (Mamba + Attention)    -> Downsample
  -> Stage 4 (Mamba + Attention)
  -> LayerNorm -> Global Avg Pool -> Linear -> Output

Channel progression: dim -> 2*dim -> 4*dim -> 8*dim
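The shape arithmetic implied by the diagram can be traced with a small helper. This is a sketch only: `ShapeSketch` is a hypothetical module, and it assumes the 4x PatchEmbed downsample and the 2x downsample between stages shown above.

```elixir
defmodule ShapeSketch do
  # Hypothetical helper for illustration; not part of Edifice.

  # Returns channel width and spatial resolution for each of the 4 stages,
  # assuming a 4x PatchEmbed downsample and a 2x downsample between stages.
  def stages(image_size, dim) do
    patch_res = div(image_size, 4)

    for i <- 0..3 do
      scale = Integer.pow(2, i)
      %{stage: i + 1, channels: dim * scale, resolution: div(patch_res, scale)}
    end
  end
end

for s <- ShapeSketch.stages(224, 80) do
  IO.puts("stage #{s.stage}: #{s.channels} channels at #{s.resolution}x#{s.resolution}")
end
```

For a 224x224 input with `dim: 80`, this gives 80 channels at 56x56 in stage 1 and 640 channels at 7x7 in stage 4.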

Model Variants

Variant   dim   depths          Params
Tiny      80    [1, 3, 8, 4]    ~32M
Small     96    [3, 3, 7, 5]    ~50M
Base      128   [3, 3, 10, 5]   ~98M

Usage

model = MambaVision.build(
  image_size: 224,
  dim: 80,
  depths: [1, 3, 8, 4],
  num_heads: [2, 4, 8, 16],
  num_classes: 10
)


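Summary

Types

  • build_opt() - Options for build/1.

Functions

  • base_config() - Get the Base variant configuration.
  • build(opts \\ []) - Build a MambaVision model.
  • output_size(opts \\ []) - Get the output size of a MambaVision model.
  • small_config() - Get the Small variant configuration.
  • tiny_config() - Get the Tiny variant configuration.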

Types

build_opt()

@type build_opt() ::
  {:d_conv, pos_integer()}
  | {:d_state, pos_integer()}
  | {:depths, [pos_integer()]}
  | {:dim, pos_integer()}
  | {:dropout, float()}
  | {:image_size, pos_integer()}
  | {:in_channels, pos_integer()}
  | {:mlp_ratio, pos_integer()}
  | {:num_classes, pos_integer() | nil}
  | {:num_heads, [pos_integer()]}

Options for build/1.

Functions

base_config()

@spec base_config() :: keyword()

Get the Base variant configuration.

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MambaVision model.

Options

  • :image_size - Input image size, square (default: 224)
  • :in_channels - Number of input channels (default: 3)
  • :dim - Base channel dimension, doubles each stage (default: 80)
  • :depths - Number of blocks per stage (default: [1, 3, 8, 4])
  • :num_heads - Attention heads per stage (default: [2, 4, 8, 16])
  • :mlp_ratio - MLP expansion ratio in hybrid stages (default: 4)
  • :dropout - Dropout/drop path rate (default: 0.0)
  • :d_state - SSM state dimension (default: 8)
  • :d_conv - SSM convolution kernel size (default: 3)
  • :num_classes - Classification head size (optional; when omitted, no classification head is added)
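One way to picture how these defaults combine with caller-supplied options is a plain keyword merge. This is a sketch, not Edifice's code: `OptsSketch` is hypothetical, the `nil` default for `:num_classes` is an assumption, and so is the length check on the per-stage lists.

```elixir
defmodule OptsSketch do
  # Hypothetical helper for illustration; not part of Edifice.
  # Defaults copied from the Options list above.
  @defaults [
    image_size: 224,
    in_channels: 3,
    dim: 80,
    depths: [1, 3, 8, 4],
    num_heads: [2, 4, 8, 16],
    mlp_ratio: 4,
    dropout: 0.0,
    d_state: 8,
    d_conv: 3,
    # Assumption: an absent head is represented as nil.
    num_classes: nil
  ]

  def resolve(opts \\ []) do
    # Caller-supplied keys override the defaults.
    resolved = Keyword.merge(@defaults, opts)

    # Assumed sanity check: one depth and one head count per stage.
    if length(resolved[:depths]) != 4 or length(resolved[:num_heads]) != 4 do
      raise ArgumentError, ":depths and :num_heads must each have 4 entries"
    end

    resolved
  end
end

# Override the Tiny defaults with the Small variant's dim and depths.
IO.inspect(OptsSketch.resolve(dim: 96, depths: [3, 3, 7, 5]))
```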

Returns

  • Without :num_classes: a [batch, 8*dim] feature vector.
  • With :num_classes: [batch, num_classes] logits.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a MambaVision model.
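Based on the return shapes documented for build/1, the output size is presumably the last dimension of the model output. The sketch below encodes that reading; `OutputSizeSketch` is a hypothetical re-implementation, and the actual behavior of output_size/1 is an assumption here.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical re-implementation for illustration; not Edifice's code.

  # Assumption: the size is num_classes when a head is requested,
  # otherwise the stage-4 channel width 8 * dim.
  def output_size(opts \\ []) do
    dim = Keyword.get(opts, :dim, 80)

    case Keyword.get(opts, :num_classes) do
      nil -> 8 * dim
      n -> n
    end
  end
end

IO.inspect(OutputSizeSketch.output_size(dim: 128))
IO.inspect(OutputSizeSketch.output_size(dim: 128, num_classes: 10))
```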

small_config()

@spec small_config() :: keyword()

Get the Small variant configuration.

tiny_config()

@spec tiny_config() :: keyword()

Get the Tiny variant configuration.