xLSTM: Extended Long Short-Term Memory.
Implements the xLSTM architecture from "xLSTM: Extended Long Short-Term Memory" (Beck et al., NeurIPS 2024).
Key Innovations
xLSTM addresses three fundamental LSTM limitations:
- Inability to revise storage decisions -> Exponential gating
- Limited storage capacity -> Matrix memory (mLSTM)
- Lack of parallelizability -> mLSTM covariance update
Two Variants
sLSTM (Scalar LSTM)
- Exponential gating: i_t = exp(W_i x_t + R_i h_{t-1} + b_i)
- Normalizer state prevents overflow: n_t = f_t * n_{t-1} + i_t
- Sequential processing with memory mixing
- Good for state-tracking tasks
mLSTM (Matrix LSTM)
- Matrix memory cell: C_t = f_t * C_{t-1} + i_t * (v_t k_t^T)
- Key-value storage similar to attention
- Fully parallelizable during training
- Good for memorization tasks
Architecture
Input [batch, seq_len, embed_dim]
              |
              v
+-------------------------------------+
|             xLSTM Block             |
| +---------------------------------+ |
| | Layer Norm -> sLSTM/mLSTM       | |
| |          |                      | |
| | Layer Norm -> Feedforward       | |
| |          |                      | |
| | Residual Connection             | |
| +---------------------------------+ |
+-------------------------------------+
              | (repeat for num_layers)
              v
Output [batch, hidden_size]
Usage
# sLSTM-only model (state tracking)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  variant: :slstm
)

# mLSTM-only model (memorization)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  variant: :mlstm
)

# Mixed model (default: alternating)
model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 6,
  variant: :mixed # sLSTM at layers 1,3,5; mLSTM at 2,4,6
)
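A built model behaves like any other Axon graph. The following minimal sketch (shapes assume embed_dim: 287 and the default 60-frame window) shows how it might be initialized and run:

# Minimal sketch: run a built model on dummy input.
model = XLSTM.build(embed_dim: 287, hidden_size: 256, num_layers: 4, variant: :slstm)

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({1, 60, 287}, :f32)

# On older Axon versions the second argument is `%{}`;
# Axon >= 0.7 uses `Axon.ModelState.empty()`.
params = init_fn.(template, %{})

input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)
# output shape: {1, 256} (the last hidden state)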
References
- Paper: https://arxiv.org/abs/2405.04517
- Official code: https://github.com/NX-AI/xlstm
Summary
Functions
- Build an xLSTM model for sequence processing.
- Build a feedforward layer with GeLU activation.
- Build the mLSTM (Matrix LSTM) layer.
- Build the sLSTM (Scalar LSTM) layer.
- Build a single xLSTM block.
- Default dropout rate
- Default feedforward expansion factor
- Default head dimension for mLSTM
- Default hidden dimension
- Default number of heads for mLSTM
- Default number of layers
- Stabilization epsilon for exponential gating
- Get the output size of an xLSTM model.
- Calculate approximate parameter count for an xLSTM model.
- Get recommended defaults for sequence processing.
Types
@type build_opt() ::
        {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:variant, :slstm | :mlstm | :mixed}
        | {:num_heads, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:dropout, float()}
        | {:window_size, pos_integer()}
Options for build/1.
Functions
Build an xLSTM model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of xLSTM blocks (default: 4)
- :variant - :slstm, :mlstm, or :mixed (default: :mixed)
- :num_heads - Number of heads for mLSTM (default: 4)
- :head_dim - Dimension per head for mLSTM (default: 64)
- :expand_factor - Feedforward expansion factor (default: 2)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length (default: 60)
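For example, a call exercising several of the non-default options (values here are illustrative):

model = XLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  variant: :mlstm,
  num_heads: 4,
  head_dim: 64,
  dropout: 0.1,
  window_size: 60
)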
Returns
An Axon model that processes sequences and outputs the last hidden state.
@spec build_feedforward(Axon.t(), pos_integer(), pos_integer(), String.t()) :: Axon.t()
Build a feedforward layer with GeLU activation.
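The internals are not shown here, but given the @spec above the layer is presumably close to this sketch (the function and layer names are assumptions, not the module's actual code):

# Sketch: expand by expand_factor, apply GeLU, project back down.
defp feedforward_sketch(input, hidden_size, expand_factor, name) do
  input
  |> Axon.dense(hidden_size * expand_factor, name: "#{name}_up")
  |> Axon.gelu()
  |> Axon.dense(hidden_size, name: "#{name}_down")
end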
Build the mLSTM (Matrix LSTM) layer.
mLSTM equations:
- i_t = exp(W_i x_t + b_i) # Input gate (exponential)
- f_t = exp(W_f x_t + b_f) # Forget gate (exponential)
- o_t = sigmoid(W_o x_t + b_o) # Output gate (sigmoid)
- k_t = W_k x_t # Key projection
- v_t = W_v x_t # Value projection
- q_t = W_q x_t # Query projection
- C_t = f_t * C_{t-1} + i_t * (v_t k_t^T) # Matrix memory
- n_t = f_t * n_{t-1} + i_t * k_t # Normalizer
- h_t = o_t * (C_t q_t / max(|q_t^T n_t|, 1)) # Hidden state
The matrix memory C stores key-value associations like attention.
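As a worked illustration of these equations (a sketch, not the module's actual implementation), a single mLSTM step for one example, with gates and projections precomputed from x_t, could be written in Nx as:

import Nx.Defn

# One mLSTM step. c is the {d, d} matrix memory, n the {d} normalizer;
# i and f are scalar exponential gates, o a {d} sigmoid gate,
# and q, k, v are {d} projections of the current input.
defn mlstm_step_sketch(c, n, i, f, o, q, k, v) do
  c = f * c + i * Nx.outer(v, k)           # C_t = f_t * C_{t-1} + i_t * v_t k_t^T
  n = f * n + i * k                        # n_t = f_t * n_{t-1} + i_t * k_t
  denom = Nx.max(Nx.abs(Nx.dot(q, n)), 1) # max(|q_t^T n_t|, 1)
  h = o * (Nx.dot(c, q) / denom)           # h_t = o_t * (C_t q_t / denom)
  {c, n, h}
end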
Build the sLSTM (Scalar LSTM) layer.
sLSTM equations with log-domain stabilization:
- log_i_t = W_i x_t + R_i h_{t-1} + b_i
- log_f_t = W_f x_t + R_f h_{t-1} + b_f
- z_t = tanh(W_z x_t + R_z h_{t-1} + b_z)
- o_t = sigmoid(W_o x_t + R_o h_{t-1} + b_o)
Log-domain stabilization (prevents exponential overflow):
- m_t = max(log_f_t + m_{t-1}, log_i_t)
- i_t' = exp(log_i_t - m_t)
- f_t' = exp(log_f_t + m_{t-1} - m_t)
- c_t = f_t' * c_{t-1} + i_t' * z_t
- n_t = f_t' * n_{t-1} + i_t'
- h_t = o_t * (c_t / max(|n_t|, 1))
The recurrent connections R_i, R_f, R_z, R_o enable memory mixing.
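A worked Nx illustration of these equations (again a sketch, not the module's actual code):

import Nx.Defn

# One sLSTM step with log-domain stabilization. c, n, m are the cell,
# normalizer, and stabilizer states; log_i and log_f are pre-activation
# gate values; z and o are already tanh-/sigmoid-activated.
defn slstm_step_sketch(c, n, m, log_i, log_f, z, o) do
  m_new = Nx.max(log_f + m, log_i)    # m_t = max(log_f_t + m_{t-1}, log_i_t)
  i = Nx.exp(log_i - m_new)           # i_t' = exp(log_i_t - m_t)
  f = Nx.exp(log_f + m - m_new)       # f_t' = exp(log_f_t + m_{t-1} - m_t)
  c = f * c + i * z                   # c_t = f_t' * c_{t-1} + i_t' * z_t
  n = f * n + i                       # n_t = f_t' * n_{t-1} + i_t'
  h = o * (c / Nx.max(Nx.abs(n), 1))  # h_t = o_t * (c_t / max(|n_t|, 1))
  {c, n, m_new, h}
end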
Build a single xLSTM block.
xLSTM block structure:
- LayerNorm -> sLSTM/mLSTM -> Residual
- LayerNorm -> Feedforward -> Residual
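In Axon terms the block plausibly looks like the following sketch (the build_slstm call and layer names are assumptions; build_feedforward matches the @spec above):

defp xlstm_block_sketch(input, hidden_size, name) do
  # Pre-norm sequence mixing with a residual connection
  mixed =
    input
    |> Axon.layer_norm(name: "#{name}_norm1")
    |> then(&build_slstm(&1, hidden_size, "#{name}_slstm"))

  x = Axon.add(input, mixed, name: "#{name}_res1")

  # Pre-norm feedforward with a second residual connection
  ff =
    x
    |> Axon.layer_norm(name: "#{name}_norm2")
    |> then(&build_feedforward(&1, hidden_size, 2, "#{name}_ff"))

  Axon.add(x, ff, name: "#{name}_res2")
end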
@spec default_dropout() :: float()
Default dropout rate
@spec default_expand_factor() :: pos_integer()
Default feedforward expansion factor
@spec default_head_dim() :: pos_integer()
Default head dimension for mLSTM
@spec default_num_heads() :: pos_integer()
Default number of heads for mLSTM
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec gate_eps() :: float()
Stabilization epsilon for exponential gating
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an xLSTM model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for an xLSTM model.
@spec recommended_defaults() :: keyword()
Get recommended defaults for sequence processing.
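Illustrative use of the introspection helpers (return values depend on the options; the numbers shown are hypothetical):

opts = Keyword.put(XLSTM.recommended_defaults(), :embed_dim, 287)

XLSTM.output_size(opts)  # e.g. 256 (matches :hidden_size)
XLSTM.param_count(opts)  # approximate trainable parameter count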