Edifice.Recurrent.XLSTM (Edifice v0.2.0)


xLSTM: Extended Long Short-Term Memory.

Implements the xLSTM architecture from "xLSTM: Extended Long Short-Term Memory" (Beck et al., NeurIPS 2024).

Key Innovations

xLSTM addresses three fundamental LSTM limitations:

  1. Inability to revise storage decisions -> Exponential gating
  2. Limited storage capacity -> Matrix memory (mLSTM)
  3. Lack of parallelizability -> mLSTM covariance update

Two Variants

sLSTM (Scalar LSTM)

  • Exponential gating: i_t = exp(W_i x_t + R_i h_{t-1} + b_i)
  • Normalizer state keeps outputs bounded: n_t = f_t * n_{t-1} + i_t (overflow itself is handled by the log-domain stabilizer state m_t)
  • Sequential processing with memory mixing
  • Good for state-tracking tasks

mLSTM (Matrix LSTM)

  • Matrix memory cell: C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)
  • Key-value storage similar to attention
  • Fully parallelizable during training
  • Good for memorization tasks

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|             xLSTM Block              |
|  +--------------------------------+  |
|  | LayerNorm -> sLSTM/mLSTM       |  |
|  |      |  (+ residual)           |  |
|  |      v                         |  |
|  | LayerNorm -> Feedforward       |  |
|  |      |  (+ residual)           |  |
|  +--------------------------------+  |
+--------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]

Usage

# sLSTM-only model (state tracking)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  variant: :slstm
)

# mLSTM-only model (memorization)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  variant: :mlstm
)

# Mixed model (default: alternating)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  variant: :mixed  # sLSTM at layers 1,3,5; mLSTM at 2,4,6
)

Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build/1 - Build an xLSTM model for sequence processing.
  • build_feedforward/4 - Build a feedforward layer with GeLU activation.
  • build_mlstm_layer/2 - Build the mLSTM (Matrix LSTM) layer.
  • build_slstm_layer/2 - Build the sLSTM (Scalar LSTM) layer.
  • build_xlstm_block/2 - Build a single xLSTM block.
  • default_dropout/0 - Default dropout rate.
  • default_expand_factor/0 - Default feedforward expansion factor.
  • default_head_dim/0 - Default head dimension for mLSTM.
  • default_hidden_size/0 - Default hidden dimension.
  • default_num_heads/0 - Default number of heads for mLSTM.
  • default_num_layers/0 - Default number of layers.
  • gate_eps/0 - Stabilization epsilon for exponential gating.
  • output_size/1 - Get the output size of an xLSTM model.
  • param_count/1 - Calculate approximate parameter count for an xLSTM model.
  • Get recommended defaults for sequence processing.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:variant, :slstm | :mlstm | :mixed}
  | {:num_heads, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an xLSTM model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of xLSTM blocks (default: 4)
  • :variant - :slstm, :mlstm, or :mixed (default: :mixed)
  • :num_heads - Number of heads for mLSTM (default: 4)
  • :head_dim - Dimension per head for mLSTM (default: 64)
  • :expand_factor - Feedforward expansion factor (default: 2)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_feedforward(input, hidden_size, expand_factor, name)

@spec build_feedforward(Axon.t(), pos_integer(), pos_integer(), String.t()) ::
  Axon.t()

Build a feedforward layer with GeLU activation.

build_mlstm_layer(input, opts \\ [])

@spec build_mlstm_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the mLSTM (Matrix LSTM) layer.

mLSTM equations:

  • i_t = exp(W_i x_t + b_i) # Input gate (exponential)
  • f_t = exp(W_f x_t + b_f) # Forget gate (exponential)
  • o_t = sigmoid(W_o x_t + b_o) # Output gate (sigmoid)
  • k_t = W_k x_t # Key projection
  • v_t = W_v x_t # Value projection
  • q_t = W_q x_t # Query projection
  • C_t = f_t * C_{t-1} + i_t * (v_t k_t^T) # Matrix memory
  • n_t = f_t * n_{t-1} + i_t * k_t # Normalizer
  • h_t = o_t * (C_t q_t / max(|q_t^T n_t|, 1)) # Hidden state

The matrix memory C stores key-value associations like attention.
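The covariance update above can be checked numerically. The following is an illustrative NumPy sketch of a single mLSTM recurrence step, not the library's implementation: gate pre-activations are passed in as plain numbers, and the function name `mlstm_step` is hypothetical. With the forget gate closed (f_t = 0) and a unit-norm key, querying with the same key recovers the stored value exactly.

```python
import numpy as np

# One mLSTM recurrence step, following the equations above.
def mlstm_step(C, n, k, v, q, log_i, log_f, o):
    i, f = np.exp(log_i), np.exp(log_f)
    C = f * C + i * np.outer(v, k)            # matrix memory: C_t = f*C + i*(v k^T)
    n = f * n + i * k                         # normalizer state
    h_tilde = C @ q / max(abs(q @ n), 1.0)    # normalized retrieval
    return C, n, o * h_tilde

d = 4                                          # head dimension (illustrative)
k = np.array([1.0, 0.0, 0.0, 0.0])             # unit-norm key
v = np.array([0.5, -1.0, 2.0, 0.0])            # value to store
C, n = np.zeros((d, d)), np.zeros(d)

# Store (k, v) with input gate exp(0) = 1 and forget gate exp(-inf) = 0,
# then query with q = k: retrieval reproduces the stored value.
C, n, h = mlstm_step(C, n, k, v, q=k, log_i=0.0, log_f=-np.inf, o=1.0)
print(h)  # -> [ 0.5 -1.   2.   0. ]
```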

build_slstm_layer(input, opts \\ [])

@spec build_slstm_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the sLSTM (Scalar LSTM) layer.

sLSTM equations with log-domain stabilization:

  • log_i_t = W_i x_t + R_i h_{t-1} + b_i
  • log_f_t = W_f x_t + R_f h_{t-1} + b_f
  • z_t = tanh(W_z x_t + R_z h_{t-1} + b_z)
  • o_t = sigmoid(W_o x_t + R_o h_{t-1} + b_o)

Log-domain stabilization (prevents exponential overflow):

  • m_t = max(log_f_t + m_{t-1}, log_i_t)
  • i_t' = exp(log_i_t - m_t)
  • f_t' = exp(log_f_t + m_{t-1} - m_t)
  • c_t = f_t' * c_{t-1} + i_t' * z_t
  • n_t = f_t' * n_{t-1} + i_t'
  • h_t = o_t * (c_t / max(|n_t|, 1))

The recurrent connections R_i, R_f, R_z, R_o enable memory mixing.
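The point of the stabilization above is that exp() is never applied to a large argument directly. The following is a minimal numeric sketch of one scalar cell update, assuming gate pre-activations arrive as plain floats (in the real layer they come from W x_t + R h_{t-1} + b); the function name `slstm_step` is hypothetical. Even with pre-activations of 500, which would overflow a naive exp(), the stabilized recurrence stays finite.

```python
import math

# One scalar sLSTM cell update with log-domain stabilization,
# following the equations above.
def slstm_step(c, n, m, log_i, log_f, z, o):
    m_new = max(log_f + m, log_i)          # running stabilizer state
    i_s = math.exp(log_i - m_new)          # stabilized input gate (<= 1)
    f_s = math.exp(log_f + m - m_new)      # stabilized forget gate (<= 1)
    c = f_s * c + i_s * z                  # cell state
    n = f_s * n + i_s                      # normalizer state
    h = o * (c / max(abs(n), 1.0))         # normalized hidden state
    return c, n, m_new, h

# Naive exp(500.0) overflows to inf; the stabilized update does not.
c, n, m = 0.0, 0.0, 0.0
for t in range(5):
    c, n, m, h = slstm_step(c, n, m, log_i=500.0, log_f=500.0, z=0.5, o=1.0)
    assert math.isfinite(h)
print(h)  # -> 0.5
```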

build_xlstm_block(input, opts \\ [])

@spec build_xlstm_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single xLSTM block.

xLSTM block structure:

  1. LayerNorm -> sLSTM/mLSTM -> Residual
  2. LayerNorm -> Feedforward -> Residual
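The wiring of the two sub-layers is plain pre-norm residual composition. A language-agnostic NumPy sketch, with the sequence-mixing layer (sLSTM/mLSTM) and the feedforward abstracted as callables; the names `xlstm_block` and `layer_norm` here are illustrative, not the module's functions:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize a vector to zero mean, unit variance.
    return (x - x.mean()) / np.sqrt(x.var() + eps)

def xlstm_block(x, mix, ff):
    x = x + mix(layer_norm(x))   # 1. LayerNorm -> sLSTM/mLSTM -> Residual
    x = x + ff(layer_norm(x))    # 2. LayerNorm -> Feedforward -> Residual
    return x

# Stand-in sub-layers: any function preserving the hidden dimension works.
x = np.ones(8)
y = xlstm_block(x, mix=lambda v: 0.1 * v, ff=lambda v: 0.1 * v)
print(y.shape)  # -> (8,)
```

Because both sub-layers are additive residuals, the block preserves the hidden dimension, which is what lets blocks be stacked num_layers deep.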

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_expand_factor()

@spec default_expand_factor() :: pos_integer()

Default feedforward expansion factor

default_head_dim()

@spec default_head_dim() :: pos_integer()

Default head dimension for mLSTM

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_num_heads()

@spec default_num_heads() :: pos_integer()

Default number of heads for mLSTM

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers

gate_eps()

@spec gate_eps() :: float()

Stabilization epsilon for exponential gating

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an xLSTM model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for an xLSTM model.