# `Edifice.SSM.Common`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/ssm/common.ex#L1)

Shared components for all Mamba architecture variants.

This module contains the common building blocks used across Mamba variants:
- Default hyperparameters
- Model structure builders (input projection, layer stacking, last timestep)
- Block structure (normalization, projections, gating)
- Depthwise convolution
- SSM parameter projections
- SSM discretization
- Utility functions

## Mamba Variants

All variants share the same architecture, differing only in the scan algorithm:

| Variant | Scan Algorithm | Notes |
|---------|---------------|-------|
| `Mamba` | Blelloch | Work-efficient O(L) work, O(log L) depth |
| `MambaHillisSteele` | Hillis-Steele | O(L log L) work, more parallelism |
| `MambaCumsum` | Cumsum-based | Experimental log-space approach |
| `MambaSSD` | SSD chunked | Mamba-2's matmul approach |

## See Also

- `Edifice.SSM.Mamba` - Main Mamba implementation

# `blelloch_scan`

```elixir
@spec blelloch_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Blelloch parallel scan (work-efficient O(L) work, O(log L) depth).

Uses `Enum.reduce` for the level loop, which lets XLA JIT-compile each scan level efficiently.

## Parameters

- `a` - Decay factors `[batch, seq_len, hidden_size, state_size]`
- `b` - Input contributions `[batch, seq_len, hidden_size, state_size]`

## Returns

Hidden states `[batch, seq_len, hidden_size, state_size]`
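
A minimal usage sketch, assuming discretized SSM parameters of matching shape (values here are illustrative, not meaningful):

```elixir
{batch, seq_len, hidden_size, state_size} = {2, 8, 4, 16}

# Constant decay factors and input contributions, just to exercise the scan.
a = Nx.broadcast(0.9, {batch, seq_len, hidden_size, state_size})
b = Nx.broadcast(0.1, {batch, seq_len, hidden_size, state_size})

h = Edifice.SSM.Common.blelloch_scan(a, b)
# h satisfies h[t] = a[t] * h[t-1] + b[t] and keeps the input shape.
```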

# `build_block`

```elixir
@spec build_block(Axon.t(), keyword(), (Axon.t(), keyword() -> Axon.t())) :: Axon.t()
```

Build the common Mamba block structure.

This handles everything except the SSM scan itself:
- Layer normalization
- Input projection (to 2x inner_size for x/z branches)
- X/Z branch splitting
- Depthwise convolution + SiLU on X branch
- SiLU gating on Z branch
- Gated multiplication
- Output projection

The caller provides an `ssm_builder` function that constructs the SSM layer.

## Parameters

- `input` - Input Axon node
- `opts` - Block options (hidden_size, state_size, expand_factor, conv_size, name)
- `ssm_builder` - Function `(x_activated, ssm_opts) -> Axon.t()` that builds SSM

## Returns

An Axon node representing the block output.
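
A hedged wiring sketch; `my_ssm/2` below is a hypothetical builder, not part of this module:

```elixir
input = Axon.input("sequence", shape: {nil, nil, 64})

block =
  Edifice.SSM.Common.build_block(
    input,
    [hidden_size: 64, state_size: 16, expand_factor: 2, conv_size: 4, name: "block_0"],
    fn x_activated, ssm_opts ->
      # Hypothetical: project B/C/dt, discretize, scan, and contract with C.
      my_ssm(x_activated, ssm_opts)
    end
  )
```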

# `build_depthwise_conv1d`

```elixir
@spec build_depthwise_conv1d(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()
```

Build a depthwise separable 1D convolution layer.

True Mamba uses a learned depthwise convolution (not mean pooling); this layer approximates that depthwise behavior for SSM input processing.

## Parameters

- `input` - Input Axon node `[batch, seq_len, channels]`
- `channels` - Number of output channels
- `kernel_size` - Convolution kernel size
- `name` - Layer name prefix
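
A usage sketch; in a depthwise convolution each channel gets its own kernel, so the parameter count scales with `channels * kernel_size` rather than `channels^2 * kernel_size`:

```elixir
input = Axon.input("x", shape: {nil, nil, 64})

# Width-4 depthwise conv over a 64-channel sequence, per the spec above.
conv = Edifice.SSM.Common.build_depthwise_conv1d(input, 64, 4, "mamba_conv")
```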

# `build_model`

```elixir
@spec build_model(
  keyword(),
  (Axon.t(), keyword() -> Axon.t())
) :: Axon.t()
```

Build the common Mamba model structure.

This handles:
- Input projection (if embed_dim != hidden_size)
- Layer stacking with residual connections and dropout
- Last timestep extraction

The caller provides a `block_builder` function that constructs each Mamba block.

## Parameters

- `opts` - Model options (embed_dim, hidden_size, num_layers, dropout, etc.)
- `block_builder` - Function `(input, opts) -> Axon.t()` that builds one block

## Returns

An Axon model that outputs `[batch, hidden_size]`.
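
A hedged sketch of assembling a full model; `MyVariant.build_block/2` stands in for any variant's block builder:

```elixir
model =
  Edifice.SSM.Common.build_model(
    [embed_dim: 32, hidden_size: 64, num_layers: 4, dropout: 0.1],
    fn input, opts -> MyVariant.build_block(input, opts) end
  )

# `model` is an Axon graph whose output shape is [batch, hidden_size].
```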

# `build_ssm_projections`

```elixir
@spec build_ssm_projections(
  Axon.t(),
  keyword()
) :: {Axon.t(), Axon.t(), Axon.t()}
```

Build the SSM parameter projections (B, C, dt).

These are the "selective" parameters that make Mamba input-dependent:
- B: Input matrix `[batch, seq_len, state_size]`
- C: Output matrix `[batch, seq_len, state_size]`
- dt: Discretization step `[batch, seq_len, hidden_size]`

## Parameters

- `input` - Input Axon node `[batch, seq_len, hidden_size]`
- `opts` - Options (hidden_size, state_size, dt_rank, name)

## Returns

Tuple of `{b_matrix, c_matrix, dt_proj}` Axon nodes.
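
A usage sketch, assuming an upstream node `x` of shape `[batch, seq_len, hidden_size]`:

```elixir
x = Axon.input("x", shape: {nil, nil, 64})

{b_matrix, c_matrix, dt_proj} =
  Edifice.SSM.Common.build_ssm_projections(x,
    hidden_size: 64,
    state_size: 16,
    dt_rank: 4,
    name: "ssm"
  )
```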

# `compute_ssm_output`

```elixir
@spec compute_ssm_output(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the SSM output from hidden states.

y[t] = C[t] · h[t], contracting over the state dimension.

## Parameters

- `h` - Hidden states `[batch, seq_len, hidden_size, state_size]`
- `c` - C matrix `[batch, seq_len, state_size]`

## Returns

Output tensor `[batch, seq_len, hidden_size]`
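
One plausible Nx formulation of this contraction (the module's actual implementation may differ):

```elixir
h = Nx.broadcast(0.5, {2, 8, 4, 16})  # [batch, seq, hidden, state]
c = Nx.broadcast(1.0, {2, 8, 16})     # [batch, seq, state]

# Contract the state axis of h against the state axis of c,
# batching over the batch and seq axes.
y = Nx.dot(h, [3], [0, 1], c, [2], [0, 1])
# => shape {2, 8, 4}, i.e. [batch, seq_len, hidden_size]
```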

# `default_conv_size`

```elixir
@spec default_conv_size() :: pos_integer()
```

Default convolution kernel size

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate

# `default_expand_factor`

```elixir
@spec default_expand_factor() :: pos_integer()
```

Default expansion factor E

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension D

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of Mamba blocks

# `default_state_size`

```elixir
@spec default_state_size() :: pos_integer()
```

Default SSM state dimension N

# `discretize_ssm`

```elixir
@spec discretize_ssm(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), pos_integer()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}
```

Discretize the SSM parameters for the scan.

Converts continuous-time SSM to discrete-time:
- A_bar = exp(Δ * A)
- B_bar = Δ * B
- Bx = B_bar * x

## Parameters

- `x` - Input tensor `[batch, seq_len, hidden_size]`
- `b` - B matrix `[batch, seq_len, state_size]`
- `dt` - Delta tensor `[batch, seq_len, hidden_size]`
- `state_size` - SSM state dimension

## Returns

Tuple of `{a_bar, bx}` where:
- `a_bar`: `[batch, seq_len, hidden_size, state_size]` - decay factors
- `bx`: `[batch, seq_len, hidden_size, state_size]` - input contributions
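
A hedged sketch of the broadcasting involved, assuming a fixed diagonal `A` (values illustrative):

```elixir
{batch, seq_len, hidden_size, state_size} = {2, 8, 4, 16}

x = Nx.broadcast(1.0, {batch, seq_len, hidden_size})
b = Nx.broadcast(0.5, {batch, seq_len, state_size})
dt = Nx.broadcast(0.01, {batch, seq_len, hidden_size})

dt_e = Nx.new_axis(dt, -1)  # [batch, seq, hidden, 1]
b_e = Nx.new_axis(b, 2)     # [batch, seq, 1, state]
x_e = Nx.new_axis(x, -1)    # [batch, seq, hidden, 1]

# Illustrative fixed diagonal A; Mamba actually learns A (in log space).
a = Nx.broadcast(-1.0, {1, 1, 1, state_size})

a_bar = Nx.exp(Nx.multiply(dt_e, a))                # A_bar = exp(Δ * A)
bx = dt_e |> Nx.multiply(b_e) |> Nx.multiply(x_e)   # Bx = (Δ * B) * x
# a_bar, bx: [batch, seq_len, hidden_size, state_size]
```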

# `dt_max`

```elixir
@spec dt_max() :: float()
```

Maximum delta

# `dt_min`

```elixir
@spec dt_min() :: float()
```

Minimum delta for numerical stability

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a Mamba model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a Mamba model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for real-time sequence processing (60fps).
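
Usage sketch of the inspection helpers:

```elixir
opts = Edifice.SSM.Common.recommended_defaults()

Edifice.SSM.Common.output_size(opts)  # width of the [batch, hidden_size] output
Edifice.SSM.Common.param_count(opts)  # approximate trainable-parameter count
```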

# `sequential_scan`

```elixir
@spec sequential_scan(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Sequential scan for short sequences or fallback.

Computes h[t] = a[t] * h[t-1] + b[t] for all t.

## Parameters

- `a` - Decay factors `[batch, seq_len, hidden_size, state_size]`
- `b` - Input contributions `[batch, seq_len, hidden_size, state_size]`

## Returns

Hidden states `[batch, seq_len, hidden_size, state_size]`
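
A reference sketch of the same recurrence written as a plain `Enum.reduce` over timesteps, useful for checking the parallel scans against (not necessarily how the module implements it):

```elixir
{batch, seq_len, hidden_size, state_size} = {2, 8, 4, 16}
a = Nx.broadcast(0.9, {batch, seq_len, hidden_size, state_size})
b = Nx.broadcast(0.1, {batch, seq_len, hidden_size, state_size})

h0 = Nx.broadcast(0.0, {batch, hidden_size, state_size})

{_final, steps} =
  Enum.reduce(0..(seq_len - 1), {h0, []}, fn t, {h_prev, acc} ->
    a_t = a[[.., t]]  # [batch, hidden, state]
    b_t = b[[.., t]]
    h_t = Nx.add(Nx.multiply(a_t, h_prev), b_t)
    {h_t, [h_t | acc]}
  end)

h = steps |> Enum.reverse() |> Nx.stack(axis: 1)
# => [batch, seq_len, hidden_size, state_size]
```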

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
