# `Edifice.SSM.Hybrid`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/ssm/hybrid.ex#L1)

Configurable Hybrid Backbone+Attention architecture for efficient sequence modeling.

Originally based on AI21's Jamba architecture, this module interleaves a
configurable backbone (Mamba by default) with periodic attention layers.
The backbone handles local/sequential context efficiently, while attention
captures long-range dependencies.

## Supported Backbones

| Backbone | Module | Key Mechanism |
|----------|--------|---------------|
| `:mamba` (default) | GatedSSM | Selective state space model |
| `:gru` | Axon.gru | Gated recurrent unit |
| `:rwkv` | RWKV | Linear attention (WKV) |
| `:delta_net` | DeltaNet | Delta rule linear attention |
| `:gated_delta_net` | GatedDeltaNet | Gated delta rule |
| `:griffin_lru` | Griffin RG-LRU | Real-gated linear recurrent unit |
| `{module, function}` | User-provided | Any function with the contract `(Axon.t(), keyword()) -> Axon.t()` |

## Architecture Pattern

```
Input [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────────────────────────────┐
│  Backbone Block 1                   │
├─────────────────────────────────────┤
│  Backbone Block 2                   │
├─────────────────────────────────────┤
│  Backbone Block 3                   │
├─────────────────────────────────────┤
│  Attention Block (every N layers)   │  <- Long-range dependencies
├─────────────────────────────────────┤
│  Backbone Block 4                   │
├─────────────────────────────────────┤
│  ...                                │
└─────────────────────────────────────┘
      │
      ▼
[batch, hidden_size]  (last position)
```

## Key Advantages

1. **Efficiency**: Most layers are backbone blocks whose cost grows linearly, O(L), with sequence length L
2. **Long-range**: Periodic attention captures distant dependencies
3. **Flexible**: Swap backbone without changing the hybrid structure
4. **Memory**: Far less than pure attention, slightly more than pure backbone
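To make the efficiency and memory points concrete, the sketch below counts attention-score entries (one per query/key pair) for a pure-attention stack versus the hybrid stack. It assumes full, non-windowed attention and ignores backbone cost entirely; `HybridCost` is a hypothetical helper, not part of Edifice.

```elixir
defmodule HybridCost do
  # Hypothetical helper, not part of Edifice.
  # Every `attention_every`-th layer is attention, so this many layers build scores.
  def attention_layers(num_layers, attention_every), do: div(num_layers, attention_every)

  # One full score matrix is seq_len x seq_len.
  def score_entries(seq_len), do: seq_len * seq_len

  # Total score entries: pure-attention stack vs. hybrid stack of the same depth.
  def compare(num_layers, attention_every, seq_len) do
    pure = num_layers * score_entries(seq_len)
    hybrid = attention_layers(num_layers, attention_every) * score_entries(seq_len)
    %{pure_attention: pure, hybrid: hybrid}
  end
end

HybridCost.compare(8, 4, 1024)
# => %{hybrid: 2097152, pure_attention: 8388608}
```

With `attention_every: 4`, only a quarter of the layers pay the quadratic cost; the remaining layers are linear-time backbone blocks.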

## Usage

    alias Edifice.SSM.Hybrid

    # Mamba backbone (the default), 3 backbone layers per attention layer
    model = Hybrid.build(
      embed_dim: 256,
      hidden_size: 256,
      num_layers: 8,
      attention_every: 4
    )

    # GRU backbone (classic RNN + attention)
    model = Hybrid.build(
      embed_dim: 256,
      num_layers: 6,
      backbone: :gru,
      attention_every: 3
    )

    # Gated DeltaNet backbone (linear attention + full attention)
    model = Hybrid.build(
      embed_dim: 256,
      num_layers: 6,
      backbone: :gated_delta_net,
      attention_every: 3
    )

# `build_opt`

```elixir
@type build_opt() ::
  {:attention_every, pos_integer()}
  | {:backbone, atom() | {module(), atom()}}
  | {:chunk_size, pos_integer()}
  | {:chunked_attention, boolean()}
  | {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:memory_efficient_attention, boolean()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:pre_norm, boolean()}
  | {:qk_layernorm, boolean()}
  | {:seq_len, pos_integer()}
  | {:state_size, pos_integer()}
  | {:use_sliding_window, boolean()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a hybrid backbone+attention model.

## Options

**Architecture:**
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Total number of layers (default: 6)
  - `:attention_every` - Make every N-th layer an attention layer; all other layers use the backbone (default: 3)
  - `:backbone` - Backbone type for non-attention layers (default: `:mamba`).
    Supported: `:mamba`, `:gru`, `:rwkv`, `:delta_net`, `:gated_delta_net`,
    `:griffin_lru`, or a `{module, function}` tuple for custom backbones.
    Custom functions must accept `(input :: Axon.t(), opts :: keyword()) :: Axon.t()`.

**Mamba-specific (when `backbone: :mamba`):**
  - `:state_size` - SSM state dimension (default: 16)
  - `:expand_factor` - Mamba expansion factor (default: 2)
  - `:conv_size` - Causal conv kernel size (default: 4)

**Attention-specific:**
  - `:num_heads` - Number of attention heads (default: 4)
  - `:head_dim` - Dimension per attention head (default: 64)
  - `:window_size` - Attention window size (default: 60)
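  - `:use_sliding_window` - Use sliding-window attention instead of full attention (default: true)
  - `:chunked_attention` - Use chunked attention in the attention layers to reduce memory usage (see `build_attention_layer/2`)
  - `:memory_efficient_attention` - Use memory-efficient (online-softmax) attention in the attention layers (see `build_attention_layer/2`)
  - `:chunk_size` - Chunk size for chunked or memory-efficient attention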
  - `:qk_layernorm` - Normalize Q and K before attention (default: true, stabilizes training)
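For intuition about `:window_size` and `:use_sliding_window`, here is a minimal sketch of a causal sliding-window mask. It assumes query position `i` may attend to key position `j` only when `j <= i` and `i - j < window_size`; that boundary convention is an assumption, not necessarily the one Edifice uses, and `WindowMask` is a hypothetical module.

```elixir
defmodule WindowMask do
  # Hypothetical helper: builds a seq_len x seq_len boolean mask.
  # Row i lists which key positions j query i may attend to.
  def causal_sliding(seq_len, window_size) do
    for i <- 0..(seq_len - 1)//1 do
      for j <- 0..(seq_len - 1)//1 do
        # Causal (no future keys) and within the last `window_size` positions.
        j <= i and i - j < window_size
      end
    end
  end
end

WindowMask.causal_sliding(4, 2)
# => [
#   [true, false, false, false],
#   [true, true, false, false],
#   [false, true, true, false],
#   [false, false, true, true]
# ]
```

When `window_size` is at least the sequence length, the mask reduces to an ordinary causal (lower-triangular) mask.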

**General:**
  - `:dropout` - Dropout rate (default: 0.1)
  - `:seq_len` - Fixed sequence length for JIT optimization (default: the value of `:window_size`)
  - `:pre_norm` - Use Pre-LayerNorm (default: true, more stable than Post-LayerNorm)
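The defaults listed above could be applied with `Keyword.validate!/2` (Elixir 1.13+). The sketch below is hypothetical and is not Edifice's actual option handling; it shows how `:embed_dim` stays required and how `:seq_len` falls back to `:window_size`. The chunked-attention keys are accepted without defaults because their `build/1` defaults are not documented here.

```elixir
defmodule HybridOpts do
  # Hypothetical helper mirroring the defaults documented for build/1.
  # Bare atoms are allowed keys with no default; pairs carry defaults.
  @defaults [
    :embed_dim,
    :seq_len,
    :chunk_size,
    :chunked_attention,
    :memory_efficient_attention,
    hidden_size: 256,
    num_layers: 6,
    attention_every: 3,
    backbone: :mamba,
    state_size: 16,
    expand_factor: 2,
    conv_size: 4,
    num_heads: 4,
    head_dim: 64,
    window_size: 60,
    use_sliding_window: true,
    qk_layernorm: true,
    dropout: 0.1,
    pre_norm: true
  ]

  def resolve(opts) do
    # Raises ArgumentError on unknown keys; fills in defaults for missing ones.
    opts = Keyword.validate!(opts, @defaults)
    # :embed_dim has no default, so Keyword.fetch!/2 raises KeyError if absent.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    # :seq_len falls back to :window_size when not given.
    Keyword.put_new(opts, :seq_len, Keyword.fetch!(opts, :window_size))
  end
end

opts = HybridOpts.resolve(embed_dim: 128, window_size: 30)
Keyword.get(opts, :seq_len)
# => 30
```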

## Returns

  An Axon model that outputs [batch, hidden_size] from the last position.
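Taking the last position reduces `[batch, seq_len, hidden_size]` to `[batch, hidden_size]`. With Nx this is a slice along axis 1 (for example via `Nx.slice_along_axis/4`) followed by dropping that axis; the plain-list sketch below, using a hypothetical `LastStep` module, only illustrates the shape change.

```elixir
defmodule LastStep do
  # Hypothetical helper: `batch` is a list of sequences, each a list of
  # hidden-size vectors. Keep only the final vector of each sequence.
  def take_last(batch), do: Enum.map(batch, &List.last/1)
end

# batch = 2, seq_len = 3, hidden_size = 2
batch = [
  [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
  [[0.5, 0.5], [0.0, 1.0], [2.0, 2.0]]
]

LastStep.take_last(batch)
# => [[5.0, 6.0], [2.0, 2.0]]
```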

## Examples

    # Mamba backbone (default, Jamba-style)
    model = Hybrid.build(
      embed_dim: 256,
      hidden_size: 256,
      num_layers: 6,
      attention_every: 3
    )

    # GRU backbone
    model = Hybrid.build(
      embed_dim: 256,
      num_layers: 6,
      backbone: :gru,
      attention_every: 3
    )

# `build_attention_layer`

```elixir
@spec build_attention_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build an attention layer with a residual connection and a feed-forward network (FFN).

## Options
  - `:pre_norm` - If true, apply LayerNorm before the block (Pre-LN, more stable).
  - `:qk_layernorm` - If true, normalize Q and K before attention (stabilizes training).
  - `:chunked` - If true, use chunked attention for lower memory usage (default: false).
  - `:memory_efficient` - If true, use memory-efficient attention with online softmax for true O(n) memory (default: false).
  - `:chunk_size` - Chunk size when using chunked or memory-efficient attention (default: 32).
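The `:memory_efficient` option relies on online softmax: keys are processed chunk by chunk while carrying a running maximum, a running normalizer, and a running weighted sum, so a full score row never has to be held in memory. Below is a minimal scalar sketch of that recurrence for a single query, assuming scalar values; `OnlineSoftmax` is a hypothetical module, not Edifice's kernel.

```elixir
defmodule OnlineSoftmax do
  # Softmax-weighted sum of `values` under `scores`, computed chunk by chunk.
  # State: {running_max, running_denominator, running_weighted_sum}.
  def attend(scores, values, chunk_size) do
    {_max, denom, acc} =
      Enum.zip(scores, values)
      |> Enum.chunk_every(chunk_size)
      |> Enum.reduce({:neg_inf, 0.0, 0.0}, &step/2)

    acc / denom
  end

  defp step(chunk, {m_old, d_old, a_old}) do
    chunk_max = chunk |> Enum.map(&elem(&1, 0)) |> Enum.max()
    m_new = if m_old == :neg_inf, do: chunk_max, else: max(m_old, chunk_max)
    # Rescale the previous partial sums from the old maximum to the new one.
    scale = if m_old == :neg_inf, do: 0.0, else: :math.exp(m_old - m_new)

    Enum.reduce(chunk, {m_new, d_old * scale, a_old * scale}, fn {s, v}, {m, d, a} ->
      w = :math.exp(s - m)
      {m, d + w, a + w * v}
    end)
  end

  # Reference: ordinary softmax over the whole row at once.
  def attend_direct(scores, values) do
    m = Enum.max(scores)
    ws = Enum.map(scores, &:math.exp(&1 - m))
    Enum.sum(Enum.zip_with(ws, values, &(&1 * &2))) / Enum.sum(ws)
  end
end
```

Both functions return the same result up to floating-point rounding; real kernels apply the same recurrence to value vectors and blocks of queries, so working memory depends on the chunk size rather than the full sequence length.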

# `build_backbone_layer`

```elixir
@spec build_backbone_layer(Axon.t(), atom() | {module(), atom()}, keyword()) ::
  Axon.t()
```

Build a backbone layer based on the configured backbone type.

Dispatches to the appropriate block builder. All backbone types follow
the same contract, `(Axon.t(), keyword()) -> Axon.t()`: they take a
`[batch, seq, hidden]` input and return the same shape with the residual connection applied.

## Supported Backbones

  - `:mamba` - GatedSSM Mamba block (default)
  - `:gru` - GRU recurrent layer
  - `:rwkv` - RWKV linear attention block
  - `:delta_net` - DeltaNet delta rule block
  - `:gated_delta_net` - Gated DeltaNet block
  - `:griffin_lru` - Griffin RG-LRU block
  - `{module, function}` - Custom backbone; called as `module.function(input, opts)`
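The dispatch described above can be sketched in plain Elixir. To stay runnable without Axon, the hypothetical builders below return tagged tuples instead of Axon graphs; the custom branch uses `apply/3`, matching the documented `module.function(input, opts)` call. `BackboneDispatch` and `MyBackbone` are illustrative names only.

```elixir
defmodule BackboneDispatch do
  @builtin [:mamba, :gru, :rwkv, :delta_net, :gated_delta_net, :griffin_lru]

  # Built-in atoms go to a stand-in block builder (a tagged tuple here).
  def build(input, backbone, opts) when backbone in @builtin do
    {:block, backbone, input, opts}
  end

  # {module, function} tuples are called as module.function(input, opts).
  def build(input, {module, function}, opts) when is_atom(module) and is_atom(function) do
    apply(module, function, [input, opts])
  end

  def build(_input, other, _opts) do
    raise ArgumentError, "unsupported backbone: #{inspect(other)}"
  end
end

defmodule MyBackbone do
  # Hypothetical custom backbone following the (input, opts) -> output contract.
  def block(input, _opts), do: {:custom, input}
end

BackboneDispatch.build(:x, :gru, hidden_size: 256)
# => {:block, :gru, :x, [hidden_size: 256]}

BackboneDispatch.build(:x, {MyBackbone, :block}, [])
# => {:custom, :x}
```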

# `build_delta_net_layer`

```elixir
@spec build_delta_net_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a DeltaNet backbone layer. Delegates to `DeltaNet.build_block/2`
which includes pre-norm and residual connection.

# `build_gated_delta_net_layer`

```elixir
@spec build_gated_delta_net_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a Gated DeltaNet backbone layer with pre-norm and residual connection.

# `build_gru_layer`

```elixir
@spec build_gru_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a GRU backbone layer with pre-norm and residual connection.

Wraps `Axon.gru` in the same pre-norm + residual pattern used for Mamba layers.

# `build_mamba_layer`

```elixir
@spec build_mamba_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a Mamba layer with a residual connection.

## Options
  - `:pre_norm` - If true, apply LayerNorm before block (Pre-LN, more stable).
                  If false, apply after residual (Post-LN, original transformer style).
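The two `:pre_norm` settings differ only in where LayerNorm sits relative to the residual add. A minimal sketch on plain vectors, assuming an unparameterized LayerNorm (no learned scale or shift) and an anonymous `block` function standing in for the Mamba (or GRU) sublayer; `ResidualNorm` is a hypothetical module.

```elixir
defmodule ResidualNorm do
  # LayerNorm without learned gain/bias: zero mean, unit variance.
  def layer_norm(xs, eps \\ 1.0e-5) do
    n = length(xs)
    mean = Enum.sum(xs) / n
    var = Enum.sum(Enum.map(xs, fn x -> (x - mean) * (x - mean) end)) / n
    Enum.map(xs, fn x -> (x - mean) / :math.sqrt(var + eps) end)
  end

  defp add(xs, ys), do: Enum.zip_with(xs, ys, &(&1 + &2))

  # Pre-LN: x + block(norm(x)); the residual path itself is not normalized.
  def pre_norm(x, block), do: add(x, block.(layer_norm(x)))

  # Post-LN: norm(x + block(x)); normalization is applied after the add.
  def post_norm(x, block), do: layer_norm(add(x, block.(x)))
end

double = fn xs -> Enum.map(xs, &(&1 * 2)) end
ResidualNorm.pre_norm([1.0, 2.0, 3.0], double)
ResidualNorm.post_norm([1.0, 2.0, 3.0], double)
```

Keeping the residual path un-normalized is what makes Pre-LN gradients better behaved in deep stacks, which is why it is the default here.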

# `build_rwkv_layer`

```elixir
@spec build_rwkv_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build an RWKV backbone layer with pre-norm and residual connection.

# `layer_pattern`

```elixir
@spec layer_pattern(keyword()) :: [atom()]
```

Get the layer pattern for a given configuration.

Returns a list describing each layer type for debugging/visualization.

## Examples

    iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3)
    [:mamba, :mamba, :attention, :mamba, :mamba, :attention]

    iex> Hybrid.layer_pattern(num_layers: 6, attention_every: 3, backbone: :gru)
    [:gru, :gru, :attention, :gru, :gru, :attention]
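The pattern above can be reproduced with a short sketch, assuming 1-based layer indices where every `attention_every`-th layer is attention and using the documented defaults. `PatternSketch` is hypothetical and does not handle custom `{module, function}` backbones.

```elixir
defmodule PatternSketch do
  def layer_pattern(opts) do
    num_layers = Keyword.get(opts, :num_layers, 6)
    every = Keyword.get(opts, :attention_every, 3)
    backbone = Keyword.get(opts, :backbone, :mamba)

    # Layer i (1-based) is attention when i is a multiple of `every`.
    for i <- 1..num_layers//1 do
      if rem(i, every) == 0, do: :attention, else: backbone
    end
  end
end

PatternSketch.layer_pattern(num_layers: 8, attention_every: 4)
# => [:mamba, :mamba, :mamba, :attention, :mamba, :mamba, :mamba, :attention]
```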

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of a hybrid model, i.e. `hidden_size`, the last dimension of the `[batch, hidden_size]` output.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for a hybrid model.
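As a rough illustration of what such an estimate involves, the sketch below counts only the Q, K, V, and output projections of one attention layer. It assumes biased dense layers mapping `hidden_size` to `num_heads * head_dim` and back, and it omits the FFN, norms, and backbone blocks, so it is not what `param_count/1` returns. `AttnParams` is a hypothetical module.

```elixir
defmodule AttnParams do
  # Hypothetical estimate: Q, K, V projections (hidden -> heads * head_dim)
  # plus the output projection (heads * head_dim -> hidden), all with biases.
  def projections(hidden_size, num_heads, head_dim) do
    attn_dim = num_heads * head_dim
    qkv = 3 * (hidden_size * attn_dim + attn_dim)
    out = attn_dim * hidden_size + hidden_size
    qkv + out
  end
end

# Documented defaults: hidden_size 256, num_heads 4, head_dim 64.
AttnParams.projections(256, 4, 64)
# => 263168
```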

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for real-time sequence processing.

Optimized for:
- Real-time inference (~10ms budget)
- 1-second context window
- Balance between local patterns (Mamba) and long-range context (Attention)
- Training stability (pre-norm + QK LayerNorm)

