Edifice.Recurrent.MinLSTM (Edifice v0.2.0)

Minimal LSTM (MinLSTM) - A simplified LSTM that is parallel-scannable.

Implements the MinLSTM from "Were RNNs All We Needed?" (Feng et al., 2024). MinLSTM simplifies the LSTM by removing the output gate and the hidden-state nonlinearity. It keeps only the forget and input gates, which are rescaled so that the normalized gates sum to one (f' + i' = 1).

Key Innovations

  • Normalized gates: f'_t + i'_t = 1 (the raw forget and input gates are rescaled so that they sum to 1)
  • No output gate: Cell state IS the hidden state
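  • No hidden-to-hidden connections in the gates: Gates depend only on the input x_t
  • Parallel scannable: Because the gates and candidate depend only on x_t, the recurrence is linear in the previous state and admits a parallel prefix scan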
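The scan property can be made explicit with a short derivation. The shorthand a_t and b_t below is introduced here for illustration only; f'_t, i'_t and candidate_t are the quantities defined under Equations. Since a_t and b_t depend only on x_t, every step is an affine map of the previous state, and composing affine maps is associative:

```latex
\begin{aligned}
c_t &= a_t \odot c_{t-1} + b_t,
  \qquad a_t = f'_t, \quad b_t = i'_t \odot \mathrm{candidate}_t, \\
c_t &= \Big(\prod_{k=1}^{t} a_k\Big) \odot c_0
  + \sum_{j=1}^{t} \Big(\prod_{k=j+1}^{t} a_k\Big) \odot b_j, \\
(a, b) \circ (a', b') &= (a \odot a',\; a' \odot b + b').
\end{aligned}
```

Because the operator in the last line is associative, the prefix compositions for all t can in principle be evaluated in O(log T) parallel depth rather than T sequential steps.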

Equations

f_t = sigmoid(linear_f(x_t))           # Forget gate
i_t = sigmoid(linear_i(x_t))           # Input gate
f'_t = f_t / (f_t + i_t)               # Normalized forget
i'_t = i_t / (f_t + i_t)               # Normalized input
candidate_t = linear_h(x_t)            # Candidate value
c_t = f'_t * c_{t-1} + i'_t * candidate_t  # Cell state = hidden state
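As a rough illustration of these equations, the following sketch runs the recurrence for a single channel in plain Elixir, once step by step and once through an associative combine. The module `MinLSTMSketch` and all of its functions are hypothetical names made up for this example; they are not part of Edifice, and the gate pre-activations stand in for the outputs of linear_f, linear_i and linear_h.

```elixir
defmodule MinLSTMSketch do
  # Hypothetical helper module for illustration; not Edifice code.

  # Logistic sigmoid on a scalar.
  def sigmoid(z), do: 1.0 / (1.0 + :math.exp(-z))

  # Normalized gates {f', i'} from the two gate pre-activations.
  def gates(f_pre, i_pre) do
    f = sigmoid(f_pre)
    i = sigmoid(i_pre)
    {f / (f + i), i / (f + i)}
  end

  # One step: c_t = f'_t * c_{t-1} + i'_t * candidate_t.
  def step(c_prev, {f_pre, i_pre, candidate}) do
    {f_n, i_n} = gates(f_pre, i_pre)
    f_n * c_prev + i_n * candidate
  end

  # Sequential reference: returns [c_1, ..., c_T].
  def run(steps, c0 \\ 0.0) do
    Enum.scan(steps, c0, fn s, c_prev -> step(c_prev, s) end)
  end

  # The same step written as an affine map c -> a * c + b.
  def affine({f_pre, i_pre, candidate}) do
    {f_n, i_n} = gates(f_pre, i_pre)
    {f_n, i_n * candidate}
  end

  # Compose two affine maps, the left one applied first. The operator is
  # associative, which is what a parallel prefix scan relies on.
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Prefix compositions, then apply each to c0. Enum.scan itself is
  # sequential; a parallel backend could regroup the combines freely.
  def run_scan(steps, c0 \\ 0.0) do
    steps
    |> Enum.map(&affine/1)
    |> Enum.scan(&combine(&2, &1))
    |> Enum.map(fn {a, b} -> a * c0 + b end)
  end
end

# Each tuple is {forget pre-activation, input pre-activation, candidate}.
steps = [{0.5, -0.2, 1.0}, {-1.0, 0.3, -2.0}, {0.1, 0.1, 0.5}]
seq = MinLSTMSketch.run(steps)
par = MinLSTMSketch.run_scan(steps)

# Both paths agree up to floating-point rounding.
Enum.zip(seq, par)
|> Enum.each(fn {x, y} -> true = abs(x - y) < 1.0e-9 end)
```

Because f'_t + i'_t = 1, each state is a convex combination of the previous state and the candidate, so the state stays within the range spanned by the initial state and the candidates seen so far.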

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
[Input Projection] -> hidden_size
      |
      v
+---------------------------+
|     MinLSTM Layer         |
|  f = sigmoid(W_f * x)     |
|  i = sigmoid(W_i * x)     |
|  f', i' = normalize(f, i) |
|  cand = W_h * x           |
|  c = f'*c_prev + i'*cand  |
+---------------------------+
      | (repeat num_layers)
      v
[Layer Norm] -> [Last Timestep]
      |
      v
Output [batch, hidden_size]

Usage

alias Edifice.Recurrent.MinLSTM

model = MinLSTM.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  dropout: 0.1
)
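A minimal sketch of running the built model follows. It assumes that edifice, axon and nx are already project dependencies, that the model exposes a single input (so a bare tensor can be passed), and that the installed Axon is 0.7 or later (for `Axon.ModelState.empty/0`). None of these assumptions is confirmed by this page, and the sequence length of 60 simply mirrors the documented :window_size default.

```elixir
# Assumes :edifice, :axon and :nx are available in the project.
alias Edifice.Recurrent.MinLSTM

model = MinLSTM.build(embed_dim: 287, hidden_size: 256, num_layers: 4, dropout: 0.1)

# Input shape {batch, seq_len, embed_dim}; 60 is the documented :window_size default.
template = Nx.template({1, 60, 287}, :f32)

# Compile the model into an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model)

# Assumption: Axon 0.7+ where parameters live in an Axon.ModelState.
params = init_fn.(template, Axon.ModelState.empty())

# A dummy all-zeros batch with the same shape as the template.
input = Nx.broadcast(0.0, {1, 60, 287})
output = predict_fn.(params, input)

# Per the architecture diagram, the result should be {batch, hidden_size}.
IO.inspect(Nx.shape(output), label: "output shape")
```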

References

  • Feng et al., "Were RNNs All We Needed?" (2024)

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a MinLSTM model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of MinLSTM layers (default: 4)
  • :dropout - Dropout rate between layers (default: 0.1)
  • :window_size - Expected sequence length (default: 60)

Returns

An Axon model that takes input of shape [batch, seq_len, embed_dim] and returns the hidden state at the last timestep, with shape [batch, hidden_size].

default_dropout()

@spec default_dropout() :: float()

Default dropout rate (0.1).

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension (256).

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers (4).

norm_eps()

@spec norm_eps() :: float()

Normalization epsilon

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a MinLSTM model.