Minimal GRU (MinGRU) - A simplified GRU with a single gate.
Implements the MinGRU from "Were RNNs All We Needed?" (Feng et al., 2024). MinGRU strips the GRU down to its essential component: a single forget/update gate. This makes it parallel-scannable during training while preserving the core gating mechanism that makes GRUs effective.
Key Innovations
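- Single gate: only one gate, z_t, controls the interpolation (a standard GRU has two gates, update and reset)
- No hidden-to-hidden connections: the gate and the candidate depend only on the input x_t, not on the previous hidden state
- Parallel scannable: because h_t is an affine function of h_{t-1} whose coefficients depend only on the input, the recurrence admits a parallel prefix scan
- ~30 lines of core logic: drastically simpler than a standard GRU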
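The parallel-scan property can be checked in a few lines. The sketch below uses a hypothetical pure-Elixir module, MinGRUScanSketch, which is not part of the library. It uses a scalar hidden state and treats each step of h_t = (1 - z_t) * h_{t-1} + z_t * candidate_t as the affine map h -> a_t * h + b_t, where a_t = 1 - z_t and b_t = z_t * candidate_t. Composing two such maps is associative, so a generic Hillis-Steele scan can compute every prefix in ceil(log2 n) rounds. This illustrates the idea only. It is not necessarily the scan used by the library or the paper.

```elixir
defmodule MinGRUScanSketch do
  # Hypothetical module: scalar hidden state, illustration only.

  # Logistic sigmoid used for the update gate z_t.
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Map gate logits (stand-ins for linear_z(x_t)) and candidates
  # (stand-ins for linear_h(x_t)) to affine coefficients {a_t, b_t}.
  def coefficients(gate_logits, candidates) do
    Enum.zip_with(gate_logits, candidates, fn g, c ->
      z = sigmoid(g)
      {1.0 - z, z * c}
    end)
  end

  # Applying {a1, b1} and then {a2, b2} gives h -> a1*a2*h + (a2*b1 + b2).
  # This composition is associative, which is what a prefix scan needs.
  def combine({a1, b1}, {a2, b2}), do: {a1 * a2, a2 * b1 + b2}

  # Reference: the ordinary left-to-right recurrence, returning every h_t.
  def sequential(coeffs, h0) do
    Enum.scan(coeffs, h0, fn {a, b}, h -> a * h + b end)
  end

  # Inclusive Hillis-Steele scan over the coefficient pairs.
  def prefix_scan(coeffs), do: do_rounds(List.to_tuple(coeffs), 1, length(coeffs))

  defp do_rounds(t, offset, n) when offset >= n, do: Tuple.to_list(t)

  defp do_rounds(t, offset, n) do
    # Within a round, each element update reads only the previous round.
    next =
      for i <- 0..(n - 1) do
        if i >= offset, do: combine(elem(t, i - offset), elem(t, i)), else: elem(t, i)
      end

    do_rounds(List.to_tuple(next), offset * 2, n)
  end

  # Each prefix composition {A_t, B_t} maps h0 directly to h_t.
  def parallel(coeffs, h0) do
    coeffs |> prefix_scan() |> Enum.map(fn {a, b} -> a * h0 + b end)
  end
end
```

The updates within a round do not depend on one another, so hardware could run them in parallel. Here a comprehension simply computes them in turn. Up to floating-point rounding, `sequential/2` and `parallel/2` should return the same list of hidden states.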
Equations
z_t = sigmoid(linear_z(x_t)) # Update gate (input-only)
candidate_t = linear_h(x_t) # Candidate (no hidden dependency)
h_t = (1 - z_t) * h_{t-1} + z_t * candidate_t # Interpolation
Architecture
Input [batch, seq_len, embed_dim]
|
v
[Input Projection] -> hidden_size
|
v
+---------------------------+
| MinGRU Layer |
| z = sigmoid(W_z * x) |
| c = W_h * x |
| h = (1-z)*h + z*c |
+---------------------------+
| (repeat num_layers)
v
[Layer Norm] -> [Last Timestep]
|
v
Output [batch, hidden_size]
Usage
model = MinGRU.build(
embed_dim: 287,
hidden_size: 256,
num_layers: 4,
dropout: 0.1
)
References
- Feng et al. (2024), "Were RNNs All We Needed?"
Summary
Functions
build/1 - Build a MinGRU model for sequence processing.
default_dropout/0 - Default dropout rate
Default hidden dimension
default_num_layers/0 - Default number of layers
output_size/1 - Get the output size of a MinGRU model.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a MinGRU model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of MinGRU layers (default: 4)
- :dropout - Dropout rate between layers (default: 0.1)
- :window_size - Expected sequence length (default: 60)
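The following hypothetical sketch shows how the documented options and defaults fit together. It uses Keyword.validate!/2, which is available since Elixir 1.13. MinGRUOptionsSketch and resolve!/1 are illustrative names, and the real build/1 may resolve its options differently. The sketch accepts :seq_len without a default only because that key appears in build_opt().

```elixir
defmodule MinGRUOptionsSketch do
  # Hypothetical helper, not part of the documented API.
  # Plain atoms are allowed keys without defaults; pairs carry the
  # documented defaults for build/1.
  @allowed [
    :embed_dim,
    :seq_len,
    hidden_size: 256,
    num_layers: 4,
    dropout: 0.1,
    window_size: 60
  ]

  def resolve!(opts) do
    # Raises ArgumentError on unknown keys and fills in missing defaults.
    opts = Keyword.validate!(opts, @allowed)
    # :embed_dim is required, so raise KeyError if it was not supplied.
    _embed_dim = Keyword.fetch!(opts, :embed_dim)
    opts
  end
end
```

For example, `MinGRUOptionsSketch.resolve!(embed_dim: 287)` would return a keyword list containing the four documented defaults alongside `embed_dim: 287`.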
Returns
An Axon model that processes sequences and outputs the last hidden state.
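To make the architecture diagram and the return value concrete, the hypothetical module below, MinGRUForwardSketch, runs the documented pipeline on plain lists. The steps are an input projection, stacked MinGRU layers, and then layer norm on the last timestep. The sketch assumes a zero initial hidden state. It omits biases, dropout, and learned layer-norm parameters, and the parameter names w_in, w_z, and w_h are illustrative only. The real Axon model may differ on any of these points.

```elixir
defmodule MinGRUForwardSketch do
  # Hypothetical pure-Elixir sketch. Vectors are lists of floats and a matrix
  # is a list of rows (out_dim x in_dim). No biases, dropout, or Nx/Axon.

  @eps 1.0e-5

  defp sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Row-major matrix-vector product W * x.
  defp matvec(w, x) do
    Enum.map(w, fn row -> row |> Enum.zip_with(x, &(&1 * &2)) |> Enum.sum() end)
  end

  # Layer norm without learned scale or shift.
  defp layer_norm(v) do
    n = length(v)
    mean = Enum.sum(v) / n
    var = Enum.sum(Enum.map(v, fn x -> (x - mean) * (x - mean) end)) / n
    Enum.map(v, fn x -> (x - mean) / :math.sqrt(var + @eps) end)
  end

  # One MinGRU layer over a whole sequence; returns h_t for every t.
  def layer(xs, %{w_z: w_z, w_h: w_h}, h0) do
    Enum.scan(xs, h0, fn x, h_prev ->
      # z_t = sigmoid(W_z x_t)
      z = Enum.map(matvec(w_z, x), &sigmoid/1)
      # candidate_t = W_h x_t
      c = matvec(w_h, x)
      # h_t = (1 - z_t) * h_{t-1} + z_t * candidate_t, elementwise
      Enum.zip_with([z, h_prev, c], fn [zi, hi, ci] -> (1.0 - zi) * hi + zi * ci end)
    end)
  end

  # xs: list of seq_len input vectors, each of length embed_dim.
  def forward(xs, %{w_in: w_in, layers: layers}) do
    # Assumed zero initial state of width hidden_size.
    h0 = List.duplicate(0.0, length(w_in))
    # Input projection: embed_dim -> hidden_size.
    projected = Enum.map(xs, &matvec(w_in, &1))
    # Feed each layer's full output sequence into the next layer.
    outputs = Enum.reduce(layers, projected, fn params, seq -> layer(seq, params, h0) end)
    # Keep only the last timestep, then normalize it.
    outputs |> List.last() |> layer_norm()
  end
end
```

Layer norm acts on each timestep independently. Normalizing only the last hidden state therefore gives the same result as normalizing the whole sequence and then taking the last timestep, as the diagram orders it.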
@spec default_dropout() :: float()
Default dropout rate
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a MinGRU model, i.e. the width of the returned last hidden state ([batch, hidden_size]).
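The default accessors and output_size/1 can be pictured with a hypothetical stand-in module. Its values come from the documented defaults. It assumes that output_size/1 returns :hidden_size, falling back to 256, because the model outputs the last hidden state. The source does not confirm how the real function is implemented.

```elixir
defmodule MinGRUDefaultsSketch do
  # Hypothetical stand-in; values mirror the documented defaults.
  @spec default_dropout() :: float()
  def default_dropout, do: 0.1

  @spec default_num_layers() :: pos_integer()
  def default_num_layers, do: 4

  # Assumption: output width equals :hidden_size (documented default 256),
  # since the model returns the last hidden state of shape [batch, hidden_size].
  @spec output_size(keyword()) :: non_neg_integer()
  def output_size(opts), do: Keyword.get(opts, :hidden_size, 256)
end
```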