# `Edifice.SSM.Zamba`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/ssm/zamba.ex#L1)

Zamba: Mamba with a single shared attention layer.

Named after Zyphra's Zamba architecture, this module implements a hybrid that is
more parameter-efficient than Jamba: instead of several independent attention
layers, it uses a single shared attention layer that is applied at regular
intervals throughout the Mamba stack.

## Key Difference from Jamba

| Aspect | Jamba | Zamba |
|--------|-------|-------|
| Attention layers | Multiple (interleaved) | One (shared weights, reused) |
| KV cache | O(L x N_attn) | O(L) - 10x reduction |
| Parameters | Higher | Lower |
| Pattern | [M, M, A, M, M, A] | [M, M, M, A, M, M, M, A], every A shares one set of weights |

## Architecture Pattern

```
Input [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────────────────────────────┐
│  Mamba Block 1                      │
├─────────────────────────────────────┤
│  ...                                │
├─────────────────────────────────────┤
│  Mamba Block K                      │
├─────────────────────────────────────┤
│  Shared Attention                   │ <- one set of weights
├─────────────────────────────────────┤
│  Mamba Blocks K+1 .. 2K             │
├─────────────────────────────────────┤
│  Shared Attention                   │ <- same weights, reused
├─────────────────────────────────────┤
│  ... (repeats up to Block N)        │
└──────────────┬──────────────────────┘
               │
               ▼
[batch, hidden_size]  (last sequence position)
```
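The schedule above can be sketched in plain Elixir, without Axon, to make the control flow explicit. Each layer is an ordinary function, and the shared attention is a single function value reused at every K-th step, standing in for one set of weights. `ZambaFlowSketch` and its arguments are hypothetical and not part of Edifice's API; placing attention after layers K, 2K, ... follows the `layer_pattern/1` example later in this page.

```elixir
# Hypothetical, Axon-free sketch of the Zamba layer schedule.
# Each "layer" is a plain function; `shared_attention` is ONE function value
# reused after every `attention_every`-th Mamba layer.
defmodule ZambaFlowSketch do
  def forward(x, mamba_layers, shared_attention, attention_every) do
    mamba_layers
    |> Enum.with_index(1)
    |> Enum.reduce(x, fn {mamba, i}, acc ->
      acc = mamba.(acc)

      # Apply the same shared attention after layers K, 2K, ...
      if rem(i, attention_every) == 0 do
        shared_attention.(acc)
      else
        acc
      end
    end)
  end
end

# Usage: record which layers run, in order.
trace_mamba = fn trace -> trace ++ [:mamba] end
trace_attention = fn trace -> trace ++ [:shared_attention] end

ZambaFlowSketch.forward([], List.duplicate(trace_mamba, 6), trace_attention, 3)
# => [:mamba, :mamba, :mamba, :shared_attention,
#     :mamba, :mamba, :mamba, :shared_attention]
```

Because the attention is passed in once and called repeatedly, there is only one "copy" of it no matter how many times it runs, which mirrors the weight sharing described above.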

## Why Single Shared Attention Works

The insight behind Zamba is that attention layers primarily serve to propagate
information globally rather than to learn diverse patterns. A single layer
with shared weights can therefore achieve similar global information flow at
a fraction of the parameter cost.
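To make the parameter-cost argument concrete, here is a rough count. It assumes a plain multi-head attention block with Q, K, V and output projections and no biases; Edifice's actual layer may include biases, norms, or other pieces. `SharedAttentionCost` is a hypothetical helper.

```elixir
# Back-of-the-envelope attention parameter count (assumption: plain MHA,
# Q/K/V/output projections, no biases). Not Edifice's param_count/1.
defmodule SharedAttentionCost do
  # Q, K, V each map hidden_size -> num_heads * head_dim;
  # the output projection maps num_heads * head_dim -> hidden_size.
  def attention_params(hidden_size, num_heads, head_dim) do
    inner = num_heads * head_dim
    3 * hidden_size * inner + inner * hidden_size
  end

  # Separate weights per application (Jamba-style) vs one shared copy (Zamba-style).
  def total_attention_params(applications, hidden_size, num_heads, head_dim, shared) do
    per_block = attention_params(hidden_size, num_heads, head_dim)
    if shared, do: per_block, else: applications * per_block
  end
end

SharedAttentionCost.total_attention_params(2, 256, 4, 64, false)
# => 524288
SharedAttentionCost.total_attention_params(2, 256, 4, 64, true)
# => 262144
```

With the `build/1` defaults (`hidden_size: 256`, `num_heads: 4`, `head_dim: 64`) and two attention applications, separate weights would cost twice as many attention parameters as one shared block under these assumptions.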

## Usage

    alias Edifice.SSM.Zamba

    # Default Zamba: 6 Mamba layers, one shared attention block applied after every 3rd layer
    model = Zamba.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      attention_every: 3
    )

    # Minimal attention variant (attention only at the end)
    model = Zamba.build(
      embed_dim: 287,
      num_layers: 6,
      attention_every: 6  # Only applied after final layer
    )

## Reference

- Paper: "Zamba: A Compact 7B SSM Hybrid Model" (arXiv:2405.16712)
- Key insight: Single shared attention achieves 10x KV cache reduction

# `build_opt`

```elixir
@type build_opt() ::
  {:attention_every, pos_integer()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `apply_shared_attention`

```elixir
@spec apply_shared_attention(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Apply the single shared attention layer.

Because every application uses the same layer name (`"shared_attention"`),
Axon reuses the same parameters for all of them. This is what distinguishes
Zamba from Jamba: one set of attention weights, applied multiple times.
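The effect of reusing one layer name can be illustrated with a name-keyed parameter map. This is a hypothetical model built only to show why a repeated name yields a single set of weights; `NamedParamsSketch` is not Axon's internal data structure, and the layer names other than `"shared_attention"` are made up for the example.

```elixir
# Hypothetical illustration: parameters stored in a map keyed by layer name.
# A name that appears twice in the layer sequence gets only ONE entry.
defmodule NamedParamsSketch do
  # Initialise parameters once per unique layer name.
  def init_params(layer_names, init_fun) do
    layer_names
    |> Enum.uniq()
    |> Map.new(fn name -> {name, init_fun.(name)} end)
  end
end

layer_sequence = [
  "mamba_1", "mamba_2", "mamba_3", "shared_attention",
  "mamba_4", "mamba_5", "mamba_6", "shared_attention"
]

params = NamedParamsSketch.init_params(layer_sequence, fn name -> {:weights_for, name} end)

map_size(params)
# => 7 (six Mamba entries plus one shared attention entry, not 8)
```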

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a Zamba model (Mamba + Single Shared Attention).

## Options

**Architecture:**
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Number of Mamba layers (default: 6)
  - `:attention_every` - Apply shared attention every N Mamba layers (default: 3)

**Mamba-specific:**
  - `:state_size` - SSM state dimension (default: 16)
  - `:expand_factor` - Mamba expansion factor (default: 2)
  - `:conv_size` - Causal conv kernel size (default: 4)

**Shared Attention:**
  - `:num_heads` - Number of attention heads (default: 4)
  - `:head_dim` - Dimension per attention head (default: 64)
  - `:window_size` - Attention window size (default: 60)

**General:**
  - `:dropout` - Dropout rate (default: 0.1)
  - `:seq_len` - Fixed sequence length for JIT optimization (default: window_size)
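A minimal sketch of how these documented defaults combine, assuming they are merged into a keyword list. `ZambaOptsSketch` is hypothetical and `build/1` may validate options differently; the point is that `:embed_dim` is required and `:seq_len` falls back to the resolved `:window_size`.

```elixir
# Hypothetical option resolution using the defaults listed above.
defmodule ZambaOptsSketch do
  @defaults [
    hidden_size: 256,
    num_layers: 6,
    attention_every: 3,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    num_heads: 4,
    head_dim: 64,
    window_size: 60,
    dropout: 0.1
  ]

  def resolve(opts) do
    # :embed_dim has no default, so raise KeyError if it is missing.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)

    resolved = Keyword.merge(@defaults, opts)
    # :seq_len defaults to whatever :window_size resolved to.
    Keyword.put_new(resolved, :seq_len, Keyword.fetch!(resolved, :window_size))
  end
end

ZambaOptsSketch.resolve(embed_dim: 287, window_size: 30)[:seq_len]
# => 30
```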

## Returns

  An Axon model that outputs [batch, hidden_size] from the last position.

## Example

    model = Zamba.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      attention_every: 3  # Shared attention applied 2x total
    )

# `build_mamba_layer`

```elixir
@spec build_mamba_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a Mamba layer with a residual connection.

Uses Pre-LayerNorm (normalize the input before the Mamba block, then add the
un-normalized input back) for stability.
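The Pre-LayerNorm residual wiring, `out = x + sublayer(layer_norm(x))`, can be sketched on a single vector represented as a list of floats. This is a simplified, hypothetical illustration: the Mamba block is replaced by an arbitrary function, and the LayerNorm has no learned scale or shift.

```elixir
# Pre-LN residual on one vector: out = x + sublayer(layer_norm(x)).
# Simplifications (assumptions): no learned gamma/beta, sublayer is any fn.
defmodule PreNormResidualSketch do
  @eps 1.0e-5

  # Normalize to zero mean and (approximately) unit variance.
  def layer_norm(x) do
    n = length(x)
    mean = Enum.sum(x) / n
    var = Enum.sum(Enum.map(x, fn v -> (v - mean) * (v - mean) end)) / n
    Enum.map(x, fn v -> (v - mean) / :math.sqrt(var + @eps) end)
  end

  # The residual path carries the ORIGINAL x; only the sublayer sees the normed input.
  def residual_block(x, sublayer) do
    normed = layer_norm(x)
    Enum.zip_with(x, sublayer.(normed), fn a, b -> a + b end)
  end
end

zero_sublayer = fn v -> Enum.map(v, fn _ -> 0.0 end) end
PreNormResidualSketch.residual_block([1.0, 2.0, 3.0], zero_sublayer)
# => [1.0, 2.0, 3.0]
```

A sublayer that outputs zeros leaves the input untouched, which is one reason Pre-LN residual stacks tend to be easy to train: each block only has to learn a correction on top of the identity path.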

# `compare_to_jamba`

```elixir
@spec compare_to_jamba(keyword()) :: map()
```

Compare Zamba vs Jamba parameter counts.

Shows the parameter savings from using shared attention.

## Example

    iex> Zamba.compare_to_jamba(embed_dim: 287, num_layers: 6)
    %{
      zamba_params: 450_000,
      jamba_params: 600_000,
      savings: 150_000,
      savings_percent: 25.0
    }
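The fields in the example relate as `savings = jamba_params - zamba_params` and `savings_percent = savings / jamba_params * 100`. The sketch below reproduces only that bookkeeping; the input counts would come from estimates such as `param_count/1`. `SavingsSketch` is hypothetical, and `attention_only_savings/2` assumes that the only difference is separate versus shared attention weights.

```elixir
# Hypothetical bookkeeping behind a compare_to_jamba-style result map.
defmodule SavingsSketch do
  def compare(zamba_params, jamba_params) do
    savings = jamba_params - zamba_params

    %{
      zamba_params: zamba_params,
      jamba_params: jamba_params,
      savings: savings,
      savings_percent: Float.round(savings / jamba_params * 100, 1)
    }
  end

  # Assumption: Jamba-style keeps separate attention weights for each of
  # `applications`, Zamba-style keeps one copy, everything else equal.
  def attention_only_savings(applications, per_block_params) do
    (applications - 1) * per_block_params
  end
end

SavingsSketch.compare(450_000, 600_000)
# => %{jamba_params: 600000, savings: 150000, savings_percent: 25.0, zamba_params: 450000}
```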

# `layer_pattern`

```elixir
@spec layer_pattern(keyword()) :: [atom()]
```

Get the layer pattern for a given configuration.

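Returns a list of atoms describing each layer, for debugging/visualization.
Layers after which the shared attention block is applied are marked
`:mamba_attention`; every such entry reuses the same attention weights,
unlike Jamba's independent attention layers.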

## Example

    iex> Zamba.layer_pattern(num_layers: 6, attention_every: 3)
    [:mamba, :mamba, :mamba_attention, :mamba, :mamba, :mamba_attention]
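The example output is consistent with a simple rule: layer `i` (1-based) is marked `:mamba_attention` when `i` is a multiple of `attention_every`. A minimal sketch of that rule follows; `LayerPatternSketch` is hypothetical, and how the real function treats a `num_layers` that is not a multiple of `attention_every` is an assumption here (no trailing attention is added).

```elixir
# Hypothetical reimplementation of the pattern rule shown in the example.
defmodule LayerPatternSketch do
  def layer_pattern(opts) do
    num_layers = Keyword.get(opts, :num_layers, 6)
    every = Keyword.get(opts, :attention_every, 3)

    # `//1` keeps the range empty when num_layers is 0 (1..0 would count down).
    for i <- 1..num_layers//1 do
      if rem(i, every) == 0, do: :mamba_attention, else: :mamba
    end
  end
end

pattern = LayerPatternSketch.layer_pattern(num_layers: 6, attention_every: 3)
# => [:mamba, :mamba, :mamba_attention, :mamba, :mamba, :mamba_attention]

Enum.count(pattern, &(&1 == :mamba_attention))
# => 2 (the shared attention is applied twice, with one set of weights)
```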

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a Zamba model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a Zamba model.

Note: Zamba has fewer parameters than Jamba because the attention weights
are shared (counted once, not once per application).

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for real-time sequence processing.

Optimized for:
- Real-time inference
- 1-second context window
- Minimal memory footprint (shared attention)

---

*Consult [api-reference.md](api-reference.md) for the complete listing*
