# `Edifice.SSM.GatedSSM`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/ssm/gated_ssm.ex#L1)

GatedSSM: Simplified gated temporal network inspired by state space models.

**NOTE**: This is NOT a true Mamba implementation. It uses a simplified gating
mechanism instead of the parallel associative scan that makes Mamba efficient.
For true Mamba, see `Edifice.SSM.Mamba`.

This module provides competitive results and is numerically stable. Use it when
you want a lightweight temporal model that's simpler than true Mamba.

## How It Differs From True Mamba

| Aspect | True Mamba | GatedSSM |
|--------|------------|----------|
| Core algorithm | Parallel associative scan | Gated multiplication |
| Recurrence | `h(t) = A*h(t-1) + B*x(t)` | Sigmoid gating approximation |
| Convolution | Learned depthwise separable | Mean pooling + projection |
| Complexity | O(L) parallel | O(L) sequential approximation |

## Architecture

```
Input [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────────────────────────────┐
│         GatedSSM Block              │
│                                      │
│  ┌──── Linear (expand) ────┐        │
│  │           │              │        │
│  │   MeanPool + SiLU        │        │
│  │           │              │        │
│  │   Gated Context     Linear+SiLU   │
│  │           │              │        │
│  └───────── multiply ───────┘        │
│               │                      │
│         Linear (project)             │
└─────────────────────────────────────┘
      │
      ▼ (repeat for num_layers)
      │
      ▼
[batch, seq_len, embed_dim] -> last timestep -> [batch, embed_dim]
```

## Usage

    # Build GatedSSM backbone
    alias Edifice.SSM.GatedSSM

    model = GatedSSM.build(
      embed_dim: 256,
      hidden_size: 256,
      state_size: 16,
      num_layers: 2,
      expand_factor: 2
    )

## When To Use

- Lightweight temporal processing without full Mamba complexity
- Stable training (no NaN issues observed)
- When true Mamba isn't available or needed

# `build_opt`

```elixir
@type build_opt() ::
  {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a GatedSSM model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension D (default: 256)
  - `:state_size` - SSM state dimension N (default: 16)
  - `:expand_factor` - Expansion factor E for inner dim (default: 2)
  - `:conv_size` - 1D convolution kernel size (default: 4)
  - `:num_layers` - Number of GatedSSM blocks (default: 2)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:window_size` - Expected sequence length for JIT optimization (default: 60)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.

# `build_causal_conv1d`

```elixir
@spec build_causal_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()
```

Build a causal 1D convolution layer.

Applies convolution only over past timesteps (causal padding).
Uses a simplified approach with sliding window mean + learned projection.
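
The sliding-window mean can be illustrated on a plain list of numbers. This is a minimal sketch, not the module's implementation: `CausalMeanSketch` is a hypothetical name, the learned projection is omitted, and the shortened window at the start of the sequence is an assumption (the real layer may pad instead).

```elixir
defmodule CausalMeanSketch do
  # Causal sliding-window mean over a list of numbers.
  # Position t averages frames max(0, t - window + 1)..t, so no future
  # frame ever contributes to the output at t.
  # ASSUMPTION: early positions use a shorter window rather than padding.
  def causal_mean(xs, window) when is_list(xs) and window > 0 do
    xs
    |> Enum.with_index()
    |> Enum.map(fn {_x, t} ->
      start = max(0, t - window + 1)
      slice = Enum.slice(xs, start..t)
      Enum.sum(slice) / length(slice)
    end)
  end
end

CausalMeanSketch.causal_mean([1.0, 2.0, 3.0, 4.0], 2)
# => [1.0, 1.5, 2.5, 3.5]
```

Changing a later frame never alters an earlier output, which is the property that makes the layer usable for frame-by-frame inference.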

# `build_checkpointed`

```elixir
@spec build_checkpointed(keyword()) :: Axon.t()
```

Build a GatedSSM model with gradient checkpointing for memory-efficient training.

Same as `build/1` but applies gradient checkpointing to each block,
reducing memory usage at the cost of ~30% more compute.

## Memory Savings

For a 3-layer model with window_size=60, batch_size=256:
- Without checkpointing: ~2.5GB activation memory
- With checkpointing: ~0.8GB activation memory

## When to Use

- Training on GPU with limited VRAM
- Using large batch sizes or long sequences
- When you're hitting OOM during training

## Options

Same as `build/1`, plus:
  - `:checkpoint_every` - Checkpoint every N layers (default: 1)

# `build_mamba_block`

```elixir
@spec build_mamba_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single GatedSSM block.

The block follows the Mamba block layout:
1. Two parallel branches after input projection
2. One branch: causal Conv1D -> SiLU -> selective SSM (approximated here, see the comparison table above)
3. Other branch: Linear -> SiLU (gating)
4. Multiply outputs -> Project back
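
The gating in steps 3 and 4 can be sketched for a single timestep with plain lists. This is an illustration under stated assumptions rather than the module's code: `GatedBlockSketch` is a hypothetical name, and the input, gating, and output projections are all left out so that only the SiLU gate and the elementwise multiply remain.

```elixir
defmodule GatedBlockSketch do
  # SiLU (also called swish): x * sigmoid(x), written as x / (1 + e^-x).
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Combine the two branches for one timestep.
  # `context` stands in for the conv/SSM branch output and `gate_in` for
  # the gating branch before its SiLU; both are lists of equal length.
  def gate(context, gate_in) do
    Enum.zip_with(context, gate_in, fn c, g -> c * silu(g) end)
  end
end

# A gate input of 0.0 closes the gate completely, because silu(0.0) == 0.0.
GatedBlockSketch.gate([1.0, 2.0], [0.0, 0.0])
# => [0.0, 0.0]
```

For large positive gate inputs `silu(g)` approaches `g`, so the gate scales the context instead of suppressing it.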

## Options
  - `:hidden_size` - Internal dimension D
  - `:state_size` - SSM state dimension N
  - `:expand_factor` - Expansion factor E
  - `:conv_size` - Convolution kernel size
  - `:name` - Layer name prefix

# `build_selective_ssm`

```elixir
@spec build_selective_ssm(
  Axon.t(),
  keyword()
) :: Axon.t()
```

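Build the selective state space layer, modeled on Mamba's Selective State Space Model (S6).

In true Mamba this is the core of the architecture: an SSM where the delta, B,
and C parameters are computed from the input, making it "selective".

The reference SSM equations:
- `h(t) = exp(delta * A) * h(t-1) + delta * B * x(t)`
- `y(t) = C * h(t)`

Where delta, B, C are input-dependent projections. As noted in the comparison
table above, GatedSSM approximates this recurrence with sigmoid gating.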
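
To make the recurrence concrete, here is a scalar sketch that evaluates the two equations step by step. It is a toy under stated assumptions: `ScalarSSMSketch` is a hypothetical name, every quantity is a single float instead of a tensor, the state starts at zero, and delta, B, and C are passed in per timestep to stand in for the input-dependent projections.

```elixir
defmodule ScalarSSMSketch do
  # Runs the recurrence over a list of {x, delta, b, c} tuples with a
  # fixed scalar `a`, starting from h = 0.0, and returns the list of y(t).
  def run(inputs, a) do
    {ys, _final_h} =
      Enum.map_reduce(inputs, 0.0, fn {x, delta, b, c}, h_prev ->
        # h(t) = exp(delta * A) * h(t-1) + delta * B * x(t)
        h = :math.exp(delta * a) * h_prev + delta * b * x
        # y(t) = C * h(t)
        {c * h, h}
      end)

    ys
  end
end

# With a = 0.0 the decay factor exp(0.0) is 1.0, so the state simply
# accumulates the inputs.
ScalarSSMSketch.run([{1.0, 1.0, 1.0, 1.0}, {2.0, 1.0, 1.0, 1.0}], 0.0)
# => [1.0, 3.0]
```

A negative `a` makes `exp(delta * a)` smaller than 1, so older inputs fade from the state; a larger delta both forgets faster and weights the current input more.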

# `init_cache`

```elixir
@spec init_cache(keyword()) :: map()
```

Initialize hidden state for incremental inference.

Returns a map containing the cached state for each layer.
For each layer, we cache:
- `:h` - The SSM hidden state [batch, state_size]
- `:conv_buffer` - Buffer for causal convolution [batch, conv_size-1, inner_size]

## Options
  - `:batch_size` - Batch size (default: 1)
  - `:hidden_size` - Hidden dimension D (default: 256)
  - `:state_size` - SSM state dimension N (default: 16)
  - `:expand_factor` - Expansion factor E (default: 2)
  - `:conv_size` - Convolution kernel size (default: 4)
  - `:num_layers` - Number of GatedSSM blocks (default: 2)

## Example

    cache = GatedSSM.init_cache(batch_size: 1, hidden_size: 256)
    {output, new_cache} = GatedSSM.step(x_single_frame, params, cache, opts)
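
The per-layer shapes listed above can be worked out from the options alone. The helper below is a hypothetical sketch, not part of the module: it assumes `inner_size = hidden_size * expand_factor`, and it returns plain tuples in a list, which may not match the layout of the map that `init_cache/1` actually returns.

```elixir
defmodule CacheShapeSketch do
  # Defaults copied from the option list documented for init_cache/1.
  @defaults [
    batch_size: 1,
    hidden_size: 256,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    num_layers: 2
  ]

  # Returns one map of expected tensor shapes per layer.
  # ASSUMPTION: inner_size = hidden_size * expand_factor.
  def shapes(opts \\ []) do
    opts = Keyword.merge(@defaults, opts)
    inner_size = opts[:hidden_size] * opts[:expand_factor]

    layer = %{
      h: {opts[:batch_size], opts[:state_size]},
      conv_buffer: {opts[:batch_size], opts[:conv_size] - 1, inner_size}
    }

    List.duplicate(layer, opts[:num_layers])
  end
end

CacheShapeSketch.shapes()
# => [%{h: {1, 16}, conv_buffer: {1, 3, 512}}, %{h: {1, 16}, conv_buffer: {1, 3, 512}}]
```

The cache size depends on `conv_size` and `state_size` but not on the window length, which is why stepping costs the same for every frame.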

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a GatedSSM model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a GatedSSM model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for real-time sequence processing (60fps).

# `step`

```elixir
@spec step(Nx.Tensor.t(), map(), map(), keyword()) :: {Nx.Tensor.t(), map()}
```

Perform a single incremental step with cached state.

Takes a single frame input and the current cache, returns the output
and updated cache. This enables O(1) inference per frame instead of
O(window_size).

## Arguments
  - `x` - Single frame input [batch, hidden_size] or [batch, 1, hidden_size]
  - `params` - Model parameters (from trained model)
  - `cache` - Cache from `init_cache/1` or previous `step/4` call
  - `opts` - Keyword list of options (the fourth argument in the spec above)

## Returns
  `{output, new_cache}` where:
  - `output` - [batch, hidden_size] tensor
  - `new_cache` - Updated cache for next step

## Example

    opts = [hidden_size: 256]
    cache = GatedSSM.init_cache(opts)
    {out1, cache} = GatedSSM.step(frame1, params, cache, opts)
    {out2, cache} = GatedSSM.step(frame2, params, cache, opts)
    # out2 is equivalent to running [frame1, frame2] through full model
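
For a stream of frames, the cache threading above generalizes to a fold. The sketch below shows the pattern with a stand-in step function so that it runs without a trained model. `StreamingSketch` and `toy_step` are hypothetical names, and the running sum only mimics the `{output, new_cache}` contract of `step/4`.

```elixir
defmodule StreamingSketch do
  # Feeds frames one at a time through `step_fun`, threading the cache.
  # `step_fun` takes (frame, cache) and returns {output, new_cache}.
  # Returns {list_of_outputs, final_cache}.
  def run(frames, cache, step_fun) do
    Enum.map_reduce(frames, cache, fn frame, cache -> step_fun.(frame, cache) end)
  end
end

# Stand-in step function (hypothetical): a running sum kept in the cache.
toy_step = fn frame, cache ->
  total = cache.total + frame
  {total, %{cache | total: total}}
end

{outputs, final_cache} = StreamingSketch.run([1, 2, 3], %{total: 0}, toy_step)
# outputs     => [1, 3, 6]
# final_cache => %{total: 6}
```

With real parameters, the same loop would presumably take `&GatedSSM.step(&1, params, &2, opts)` as `step_fun`.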

