Minimal LSTM (MinLSTM) - A simplified LSTM that is parallel-scannable.
Implements the MinLSTM from "Were RNNs All We Needed?" (Feng et al., 2024). MinLSTM simplifies the LSTM by removing the output gate and the hidden-state nonlinearity, keeping only the forget and input gates, which are normalized so that f' + i' = 1.
Key Innovations
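- Normalized gates: f'_t + i'_t = 1 (the forget and input gates are rescaled so they sum to 1)
- No output gate: the cell state is the hidden state
- No hidden-to-hidden connections in the gates: the gates depend only on the input x_t
- Parallel scannable: because the gates and candidate depend only on the input, the recurrence is linear in c_{t-1} and admits a parallel prefix scan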
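The last two points can be illustrated with a minimal sketch in plain Elixir (no Nx or Axon), assuming scalar states and already-computed gate pre-activations. The module `MinLSTMSketch` and its functions are hypothetical and are not part of this library. Since the gates depend only on the input, each timestep is an affine map `h -> a * h + b` with `a = f'` and `b = i' * candidate`. Composing affine maps is associative, which is the property a parallel prefix scan relies on. `Enum.scan` itself runs sequentially here; the sketch only checks that the two formulations agree.

```elixir
defmodule MinLSTMSketch do
  # Hypothetical helper module, for illustration only.

  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Turn gate pre-activations and a candidate into an affine step {a, b},
  # meaning h_t = a * h_prev + b with a = f' and b = i' * candidate.
  def step({f_pre, i_pre, candidate}) do
    f = sigmoid(f_pre)
    i = sigmoid(i_pre)
    total = f + i
    {f / total, i / total * candidate}
  end

  # Compose two affine steps: apply the first, then the second.
  def combine({a1, b1}, {a2, b2}), do: {a2 * a1, a2 * b1 + b2}

  # Sequential reference: run the recurrence over time starting from h0.
  def sequential(steps, h0) do
    Enum.scan(steps, h0, fn {a, b}, h -> a * h + b end)
  end

  # Scan formulation: prefix-compose the steps, then apply each prefix to h0.
  def via_scan(steps, h0) do
    steps
    |> Enum.scan(&combine(&2, &1))
    |> Enum.map(fn {a, b} -> a * h0 + b end)
  end
end

# Three timesteps of {forget pre-activation, input pre-activation, candidate}.
steps =
  Enum.map(
    [{0.5, -0.2, 1.0}, {-1.0, 2.0, 0.3}, {0.1, 0.1, -0.7}],
    &MinLSTMSketch.step/1
  )

# Both calls yield the same hidden states, up to floating-point rounding.
IO.inspect(MinLSTMSketch.sequential(steps, 0.0))
IO.inspect(MinLSTMSketch.via_scan(steps, 0.0))
```

Because `combine/2` is associative, a parallel implementation is free to regroup the compositions in a tree instead of left to right.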
Equations
f_t = sigmoid(linear_f(x_t)) # Forget gate
i_t = sigmoid(linear_i(x_t)) # Input gate
f'_t = f_t / (f_t + i_t) # Normalized forget
i'_t = i_t / (f_t + i_t) # Normalized input
candidate_t = linear_h(x_t) # Candidate value
c_t = f'_t * c_{t-1} + i'_t * candidate_t # Cell state = hidden state
Architecture
Input [batch, seq_len, embed_dim]
|
v
[Input Projection] -> hidden_size
|
v
+---------------------------+
| MinLSTM Layer |
| f = sigmoid(W_f * x) |
| i = sigmoid(W_i * x) |
| f', i' = normalize(f,i) |
| c = W_h * x |
| h = f'*h + i'*c |
+---------------------------+
| (repeat num_layers)
v
[Layer Norm] -> [Last Timestep]
|
v
Output [batch, hidden_size]
Usage
model = MinLSTM.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 4,
dropout: 0.1
)
Summary
Functions
Build a MinLSTM model for sequence processing.
Default dropout rate
Default hidden dimension
Default number of layers
Normalization epsilon
Get the output size of a MinLSTM model.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a MinLSTM model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of MinLSTM layers (default: 4)
- :dropout - Dropout rate between layers (default: 0.1)
- :window_size - Expected sequence length (default: 60)
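As a rough illustration of how these options and defaults fit together, the following sketch resolves a keyword list in plain Elixir. It is an assumption about how such options might be handled, not the body of build/1; the module `BuildOptsSketch` and the function `resolve/1` are hypothetical names.

```elixir
defmodule BuildOptsSketch do
  # Hypothetical helper; the defaults mirror the documented option values.
  @defaults [hidden_size: 256, num_layers: 4, dropout: 0.1, window_size: 60]

  def resolve(opts) do
    # :embed_dim has no default, so fail loudly (KeyError) when it is missing.
    _ = Keyword.fetch!(opts, :embed_dim)

    # Caller-supplied values override the defaults.
    Keyword.merge(@defaults, opts)
  end
end

resolved = BuildOptsSketch.resolve(embed_dim: 287, num_layers: 2)
IO.inspect(Keyword.get(resolved, :num_layers))
IO.inspect(Keyword.get(resolved, :hidden_size))
```

Using `Keyword.fetch!/2` for the required key surfaces a missing `:embed_dim` immediately rather than later, during model construction.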
Returns
An Axon model that processes sequences and outputs the last hidden state.
@spec default_dropout() :: float()
Default dropout rate
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec norm_eps() :: float()
Normalization epsilon
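This page does not show how norm_eps/0 is used. One plausible reading, stated here as an assumption, is that the epsilon is added to the denominator f_t + i_t when the gates are normalized, so the division stays defined if both gates underflow to zero. The module `NormEpsSketch` and the value `1.0e-6` below are hypothetical; the library's actual value is whatever norm_eps/0 returns.

```elixir
defmodule NormEpsSketch do
  # Hypothetical epsilon, chosen only for this illustration.
  @eps 1.0e-6

  # Normalize two gate activations so they sum to approximately 1.
  # The epsilon keeps the division defined even when both gates are 0.0,
  # where a plain f / (f + i) would raise ArithmeticError in Elixir.
  def normalize(f, i, eps \\ @eps) do
    total = f + i + eps
    {f / total, i / total}
  end
end

{f_norm, i_norm} = NormEpsSketch.normalize(0.9, 0.3)
# The sum is slightly below 1.0 because of the epsilon in the denominator.
IO.inspect(f_norm + i_norm)

# Degenerate case: both gates are zero, and the result is still well defined.
IO.inspect(NormEpsSketch.normalize(0.0, 0.0))
```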
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a MinLSTM model.