# `Edifice.Recurrent.XLSTM`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/recurrent/xlstm.ex#L1)

xLSTM: Extended Long Short-Term Memory.

Implements the xLSTM architecture from "xLSTM: Extended Long Short-Term Memory"
(Beck et al., NeurIPS 2024).

## Key Innovations

xLSTM addresses three fundamental LSTM limitations:
1. Inability to revise storage decisions -> **Exponential gating**
2. Limited storage capacity -> **Matrix memory (mLSTM)**
3. Lack of parallelizability -> **mLSTM covariance update**

## Two Variants

### sLSTM (Scalar LSTM)
- Exponential gating: `i_t = exp(W_i x_t + R_i h_{t-1} + b_i)`
- Normalizer state `n_t = f_t * n_{t-1} + i_t` keeps the output bounded
  (with log-domain stabilization to prevent `exp` overflow)
- Sequential processing with memory mixing
- Good for state-tracking tasks

### mLSTM (Matrix LSTM)
- Matrix memory cell: `C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)`
- Key-value storage similar to attention
- Fully parallelizable during training
- Good for memorization tasks

## Architecture

```
Input [batch, seq_len, embed_dim]
      |
      v
+---------------------------------------+
|              xLSTM Block              |
|  +---------------------------------+  |
|  |  LayerNorm -> sLSTM/mLSTM       |  |
|  |        + residual               |  |
|  |  LayerNorm -> Feedforward       |  |
|  |        + residual               |  |
|  +---------------------------------+  |
+---------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]
```

## Usage

    # sLSTM-only model (state tracking)
    model = XLSTM.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      variant: :slstm
    )

    # mLSTM-only model (memorization)
    model = XLSTM.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      variant: :mlstm
    )

    # Mixed model (default: alternating)
    model = XLSTM.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 6,
      variant: :mixed  # sLSTM at layers 1,3,5; mLSTM at 2,4,6
    )
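Once built, the model runs like any other Axon graph. A usage sketch, assuming
Axon and Nx are available (the exact `init_fn` arguments vary across Axon
versions):

    {init_fn, predict_fn} = Axon.build(model)
    params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
    output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
    # last hidden state, shape {1, 256}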

## References
- Paper: https://arxiv.org/abs/2405.04517
- Official code: https://github.com/NX-AI/xlstm

# `build_opt`

```elixir
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:variant, :slstm | :mlstm | :mixed}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build an xLSTM model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Internal hidden dimension (default: 256)
  - `:num_layers` - Number of xLSTM blocks (default: 4)
  - `:variant` - `:slstm`, `:mlstm`, or `:mixed` (default: `:mixed`)
  - `:num_heads` - Number of heads for mLSTM (default: 4)
  - `:head_dim` - Dimension per head for mLSTM (default: 64)
  - `:expand_factor` - Feedforward expansion factor (default: 2)
  - `:dropout` - Dropout rate (default: 0.0)
  - `:window_size` - Expected sequence length (default: 60)

## Returns
  An Axon model that processes sequences and outputs the last hidden state.

# `build_feedforward`

```elixir
@spec build_feedforward(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()
```

Build a feedforward layer with GeLU activation.

# `build_mlstm_layer`

```elixir
@spec build_mlstm_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the mLSTM (Matrix LSTM) layer.

mLSTM equations:
- i_t = exp(W_i x_t + b_i)                     # Input gate (exponential)
- f_t = exp(W_f x_t + b_f)                     # Forget gate (exponential)
- o_t = sigmoid(W_o x_t + b_o)                 # Output gate (sigmoid)
- k_t = W_k x_t                                # Key projection
- v_t = W_v x_t                                # Value projection
- q_t = W_q x_t                                # Query projection
- C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)      # Matrix memory
- n_t = f_t * n_{t-1} + i_t * k_t              # Normalizer
- h_t = o_t * (C_t q_t / max(|q_t^T n_t|, 1))  # Hidden state

The matrix memory C stores key-value associations, much as attention does.
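A single mLSTM step can be traced in plain Elixir on a toy 2-dimensional state
(an illustrative sketch with hand-picked gate values, not the module's actual
layer): storing `v` under key `k` and then querying with `q = k` retrieves `v`,
which is exactly the key-value behavior described above.

```elixir
defmodule MLSTMStep do
  # One mLSTM step on list-based tensors: `c` is a d x d list-of-lists,
  # `n`, `k`, `v`, `q` are length-d lists, gates are precomputed scalars.
  def step({c, n}, {i_t, f_t, o_t}, k, v, q) do
    # C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)
    c_t =
      for {c_row, v_i} <- Enum.zip(c, v) do
        for {c_ij, k_j} <- Enum.zip(c_row, k), do: f_t * c_ij + i_t * v_i * k_j
      end

    # n_t = f_t * n_{t-1} + i_t * k_t
    n_t = for {n_i, k_i} <- Enum.zip(n, k), do: f_t * n_i + i_t * k_i

    # h_t = o_t * (C_t q_t / max(|q_t^T n_t|, 1))
    c_q = for row <- c_t, do: dot(row, q)
    denom = max(abs(dot(q, n_t)), 1.0)
    h_t = for x <- c_q, do: o_t * x / denom

    {h_t, {c_t, n_t}}
  end

  defp dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {x, y} -> x * y end) |> Enum.sum()
end
```

Starting from a zero state with `i_t = 1`, `f_t = 0.5`, `o_t = 1`, storing
`v = [0, 1]` under `k = [1, 0]` and querying with `q = [1, 0]` returns
`[0.0, 1.0]` - the stored value.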

# `build_slstm_layer`

```elixir
@spec build_slstm_layer(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the sLSTM (Scalar LSTM) layer.

sLSTM equations with log-domain stabilization:
- log_i_t = W_i x_t + R_i h_{t-1} + b_i
- log_f_t = W_f x_t + R_f h_{t-1} + b_f
- z_t = tanh(W_z x_t + R_z h_{t-1} + b_z)
- o_t = sigmoid(W_o x_t + R_o h_{t-1} + b_o)

Log-domain stabilization (prevents exponential overflow):
- m_t = max(log_f_t + m_{t-1}, log_i_t)
- i_t' = exp(log_i_t - m_t)
- f_t' = exp(log_f_t + m_{t-1} - m_t)
- c_t = f_t' * c_{t-1} + i_t' * z_t
- n_t = f_t' * n_{t-1} + i_t'
- h_t = o_t * (c_t / max(|n_t|, 1))

The recurrent connections R_i, R_f, R_z, R_o enable memory mixing.
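The stabilization can be checked numerically in plain Elixir (a toy scalar step
with precomputed pre-activations; the `W`/`R` projections are folded into the
inputs, and the initial stabilizer state is taken as 0): even an input-gate
pre-activation of 1000, which would overflow a naive `exp`, leaves every state
finite.

```elixir
defmodule SLSTMStep do
  # One scalar sLSTM step. State is {c, n, m}; inputs are the
  # pre-activations log_i, log_f, z_pre, o_pre from the equations above.
  def step({c, n, m}, {log_i, log_f, z_pre, o_pre}) do
    z = :math.tanh(z_pre)
    o = 1.0 / (1.0 + :math.exp(-o_pre))

    # m_t = max(log_f_t + m_{t-1}, log_i_t) -- running max keeps exp() bounded
    m_t = max(log_f + m, log_i)
    i_p = :math.exp(log_i - m_t)
    f_p = :math.exp(log_f + m - m_t)

    c_t = f_p * c + i_p * z
    n_t = f_p * n + i_p
    h_t = o * c_t / max(abs(n_t), 1.0)

    {h_t, {c_t, n_t, m_t}}
  end
end
```

With `log_i = 1000`, the stabilized gates become `i_t' = exp(0) = 1` and
`f_t' = exp(-1000)`, which harmlessly underflows to 0 instead of overflowing.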

# `build_xlstm_block`

```elixir
@spec build_xlstm_block(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build a single xLSTM block.

xLSTM block structure:
1. LayerNorm -> sLSTM/mLSTM -> Residual
2. LayerNorm -> Feedforward -> Residual
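The two pre-norm residual steps above can be sketched with Axon's combinators.
This is a sketch only: it assumes `build_slstm_layer/2` from this module and
standard Axon layers, and the layer names are illustrative, not the ones the
module actually emits.

```elixir
def xlstm_block_sketch(input, opts) do
  hidden = Keyword.get(opts, :hidden_size, 256)
  expand = Keyword.get(opts, :expand_factor, 2)

  # 1. LayerNorm -> sLSTM/mLSTM -> Residual
  mixed =
    input
    |> Axon.layer_norm(name: "ln_1")
    |> build_slstm_layer(opts)

  x = Axon.add(input, mixed)

  # 2. LayerNorm -> Feedforward -> Residual
  ff =
    x
    |> Axon.layer_norm(name: "ln_2")
    |> Axon.dense(hidden * expand, activation: :gelu)
    |> Axon.dense(hidden)

  Axon.add(x, ff)
end
```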

# `default_dropout`

```elixir
@spec default_dropout() :: float()
```

Default dropout rate.

# `default_expand_factor`

```elixir
@spec default_expand_factor() :: pos_integer()
```

Default feedforward expansion factor.

# `default_head_dim`

```elixir
@spec default_head_dim() :: pos_integer()
```

Default head dimension for mLSTM.

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default hidden dimension.

# `default_num_heads`

```elixir
@spec default_num_heads() :: pos_integer()
```

Default number of heads for mLSTM.

# `default_num_layers`

```elixir
@spec default_num_layers() :: pos_integer()
```

Default number of layers.

# `gate_eps`

```elixir
@spec gate_eps() :: float()
```

Stabilization epsilon for exponential gating.

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of an xLSTM model.

# `param_count`

```elixir
@spec param_count(keyword()) :: non_neg_integer()
```

Calculate approximate parameter count for an xLSTM model.

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults for sequence processing.

---

*Consult [api-reference.md](api-reference.md) for complete listing*
