# `Edifice.SSM.Mamba`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/ssm/mamba.ex#L1)

Mamba: True Selective State Space Model with optimized parallel scan.

Implements the Mamba architecture from "Mamba: Linear-Time Sequence Modeling
with Selective State Spaces" (Gu & Dao, 2023).

## Key Innovation: Parallel Associative Scan

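The SSM recurrence h[t] = A[t] * h[t-1] + B[t] * x[t] looks sequential, but each
step is an affine map of the state, and composing affine maps is associative, so
the recurrence can be parallelized:

```
Represent step t as the pair (A[t], B[t] * x[t]) and define, with the later
step on the left:
  (a, b) ⊗ (c, d) = (a*c, a*d + b)

Then the scan (with zero initial state):
  h[0] = B[0] * x[0]
  h[1] = A[1] * h[0] + B[1] * x[1]
  h[2] = A[2] * h[1] + B[2] * x[2]
  ...

can be computed in O(log L) parallel time for sequence length L using a prefix scan.
```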
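The operator can be checked against the plain recurrence with a minimal sketch in pure Elixir. The module `ScanSketch` and its functions are hypothetical and not part of Edifice; scalars stand in for tensors, and a zero initial state is assumed. `parallel_scan/1` uses Hillis-Steele doubling, so it needs about log2(L) rounds whose combines are independent of each other, although here each round is an ordinary sequential comprehension.

```elixir
defmodule ScanSketch do
  # Hypothetical helper module for illustration; not part of Edifice.

  # Each step t is the affine map h -> a * h + b with a = A[t], b = B[t] * x[t].
  # combine(later, earlier) applies `earlier` first, then `later`.
  def combine({a, b}, {c, d}), do: {a * c, a * d + b}

  # Reference: run the recurrence step by step from h = 0.
  def sequential(pairs) do
    Enum.scan(pairs, 0.0, fn {a, b}, h -> a * h + b end)
  end

  # Hillis-Steele inclusive scan: about log2(L) rounds of independent combines.
  def parallel_scan(pairs), do: do_scan(pairs, 1, length(pairs))

  defp do_scan(pairs, offset, len) when offset >= len do
    # After the last round, the second component is h[t] (initial state 0).
    Enum.map(pairs, fn {_a, b} -> b end)
  end

  defp do_scan(pairs, offset, len) do
    indexed = List.to_tuple(pairs)

    next =
      for i <- 0..(len - 1) do
        if i >= offset do
          # Merge the span ending at i with the span ending at i - offset.
          combine(elem(indexed, i), elem(indexed, i - offset))
        else
          elem(indexed, i)
        end
      end

    do_scan(next, offset * 2, len)
  end
end
```

With `pairs = [{0.5, 1.0}, {0.5, 2.0}, {0.25, 3.0}]`, both `ScanSketch.sequential(pairs)` and `ScanSketch.parallel_scan(pairs)` return `[1.0, 2.5, 3.625]`.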

## Selective Mechanism

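Unlike linear time-invariant SSMs, Mamba makes Δ, B, and C input-dependent:
- Δ (discretization step) controls how much to update the state
- B (input matrix) projects the input into the state space
- C (output matrix) projects the state to the output

All three are computed from the current input, which lets the model choose what to
store in its state and what to ignore. In the paper, A itself stays an
input-independent learned parameter; its discretized form exp(Δ * A) still varies
per step through Δ.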
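How Δ gates the update can be seen with scalars. The sketch below assumes the discretization listed under `build_selective_ssm_parallel` (A_bar = exp(Δ * A), B_bar = Δ * B) and a softplus to keep Δ positive, as in the paper. `SelectiveSketch` and its functions are hypothetical, and A and B are held fixed only to isolate the effect of Δ.

```elixir
defmodule SelectiveSketch do
  # Hypothetical illustration; not part of Edifice.

  # softplus keeps the step size positive (assumption: as in the Mamba paper).
  def softplus(z), do: :math.log(1.0 + :math.exp(z))

  # One scalar state update. A is negative, so 0 < a_bar < 1.
  def step(h, x, delta, a \\ -1.0, b \\ 1.0) do
    a_bar = :math.exp(delta * a)
    b_bar = delta * b
    a_bar * h + b_bar * x
  end
end

h = 1.0
x = 5.0

# Small Δ: a_bar is close to 1 and b_bar close to 0, so the state barely moves.
kept = SelectiveSketch.step(h, x, SelectiveSketch.softplus(-8.0))

# Large Δ: a_bar is close to 0, so the old state is mostly forgotten
# and the new value is dominated by the current input.
reset = SelectiveSketch.step(h, x, SelectiveSketch.softplus(3.0))

IO.inspect({kept, reset})
```

In a selective SSM the pre-activation passed to softplus would itself be a learned projection of the input, which is what lets the model decide per token whether to hold or overwrite its state.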

## Architecture

```
Input [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────────────────────────────────────┐
│                 Mamba Block                 │
│                                             │
│            ┌── Linear (expand) ──┐          │
│            │                     │          │
│  DepthwiseConv + SiLU            │          │
│            │                     │          │
│    Parallel Scan SSM       Linear + SiLU    │
│            │                     │          │
│            └────── multiply ─────┘          │
│                       │                     │
│               Linear (project)              │
└─────────────────────────────────────────────┘
      │
      ▼ (repeat for num_layers)
```

## Usage

    # Build Mamba backbone
    alias Edifice.SSM.Mamba

    model = Mamba.build(
      embed_dim: 287,
      hidden_size: 256,
      state_size: 16,
      num_layers: 2,
      expand_factor: 2
    )

## References
- Paper: https://arxiv.org/abs/2312.00752
- Original code: https://github.com/state-spaces/mamba

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:state_size, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Mamba model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension D (default: 256)
  - `:state_size` - SSM state dimension N (default: 16)
  - `:expand_factor` - Expansion factor E for inner dim (default: 2)
  - `:conv_size` - 1D convolution kernel size (default: 4)
  - `:num_layers` - Number of Mamba blocks (default: 2)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:window_size` - Expected sequence length for JIT optimization (default: 60)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.

# `build_depthwise_conv1d`

```elixir
@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()
```

Build a depthwise separable 1D convolution layer.
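To make the depthwise part concrete, here is a minimal sketch on plain lists: every channel is filtered with its own kernel and channels are never mixed. It covers only the depthwise step, not a pointwise projection. `DepthwiseConvSketch` is hypothetical and does not reflect how Edifice builds the layer; the causal left padding is an assumption borrowed from the reference Mamba design.

```elixir
defmodule DepthwiseConvSketch do
  # Hypothetical illustration; not part of Edifice.
  # `sequence` is a list of time steps, each a list with one value per channel.
  # `kernels` holds one list of K weights per channel (no cross-channel mixing).
  # Assumption: causal left padding with zeros, so output[t] sees only inputs <= t.
  def conv(sequence, kernels) do
    k = kernels |> hd() |> length()
    channels = length(kernels)
    padding = List.duplicate(List.duplicate(0.0, channels), k - 1)

    (padding ++ sequence)
    |> Enum.chunk_every(k, 1, :discard)
    |> Enum.map(fn window ->
      # Transpose the window to get, per channel, its K most recent values.
      window
      |> Enum.zip()
      |> Enum.map(&Tuple.to_list/1)
      |> Enum.zip(kernels)
      |> Enum.map(fn {values, weights} ->
        values |> Enum.zip(weights) |> Enum.map(fn {v, w} -> v * w end) |> Enum.sum()
      end)
    end)
  end
end
```

For `sequence = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]` and `kernels = [[1.0, 1.0], [0.0, 1.0]]`, the first channel becomes a two-step running sum and the second passes through unchanged, giving `[[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]`.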

# `build_mamba_block`

```elixir
@spec build_mamba_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single Mamba block with parallel scan SSM.

## Options
  - `:hidden_size` - Internal dimension D
  - `:state_size` - SSM state dimension N
  - `:expand_factor` - Expansion factor E
  - `:conv_size` - Convolution kernel size
  - `:name` - Layer name prefix

# `build_selective_ssm_parallel`

```elixir
@spec build_selective_ssm_parallel(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the Selective SSM with parallel associative scan.

This is the core of Mamba: an SSM whose Δ, B, and C are computed from the input
(which makes the discretized A_bar input-dependent as well), evaluated efficiently
with a parallel scan.

The discretized SSM equations:
- A_bar = exp(Δ * A)
- B_bar = Δ * B
- h[t] = A_bar * h[t-1] + B_bar * x[t]
- y[t] = C * h[t]
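The four equations can be read as a step-by-step loop. The following is a sequential reference sketch, not the parallel implementation: it assumes a single input channel, a diagonal A stored as a list of N values, a zero initial state, and per-step Δ, B, and C supplied by the caller. The module name and the map keys (`:x`, `:delta`, `:b`, `:c`) are hypothetical.

```elixir
defmodule SsmReferenceSketch do
  # Hypothetical illustration; not part of Edifice.
  # Single input channel, diagonal A given as a list of N floats.
  # Each step is a map with the input x, step size delta, and lists b and c of length N.
  def run(a, steps) do
    h0 = List.duplicate(0.0, length(a))

    {ys, _h} =
      Enum.map_reduce(steps, h0, fn %{x: x, delta: delta, b: b, c: c}, h ->
        # h[t] = A_bar * h[t-1] + B_bar * x[t], elementwise over the N state dims,
        # with A_bar = exp(delta * A) and B_bar = delta * B.
        h_next =
          [a, b, h]
          |> Enum.zip()
          |> Enum.map(fn {a_n, b_n, h_n} ->
            :math.exp(delta * a_n) * h_n + delta * b_n * x
          end)

        # y[t] = C * h[t] (dot product over the state dimension)
        y = h_next |> Enum.zip(c) |> Enum.map(fn {h_n, c_n} -> h_n * c_n end) |> Enum.sum()
        {y, h_next}
      end)

    ys
  end
end
```

A parallel scan has to produce the same `y` values as this loop, which makes a reference like this useful as a correctness check on short sequences.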

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a Mamba model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a Mamba model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for real-time sequence processing (60 fps).

