Recurrent neural network layers for temporal sequence processing.
Provides LSTM and GRU architectures for learning temporal dependencies in sequential data, which is essential for understanding:
- Multi-step action sequences
- Temporal patterns and trends
- Long-range dependencies
- Reactive decision sequences
## Architecture
The recurrent backbone processes sequences of embedded states:
```
Frame Sequence [batch, seq_len, embed_dim]
        │
        ▼
 ┌─────────────┐
 │  LSTM/GRU   │ ←─ hidden state (h, c for LSTM)
 │   Layer 1   │
 └─────────────┘
        │
        ▼
 ┌─────────────┐
 │  LSTM/GRU   │ (optional stacked layers)
 │   Layer 2   │
 └─────────────┘
        │
        ▼
Hidden Output [batch, hidden_size]
```

## Hidden State Management
For real-time inference, hidden states must be carried between frames:
```elixir
# Initialize hidden state
hidden = Recurrent.initial_hidden(model, batch_size)

# Process frame, get new hidden
{output, new_hidden} = Recurrent.forward_with_state(model, params, frame, hidden)

# Use new_hidden for next frame
...
```
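Over a stream of frames, this becomes a fold that threads the hidden state through each step. A minimal sketch, assuming `frames` is a list of embedded `[batch, embed_dim]` tensors (the variable name is illustrative):

```elixir
initial = Recurrent.initial_hidden(model, batch_size)

{outputs, _final_hidden} =
  Enum.reduce(frames, {[], initial}, fn frame, {outs, hidden} ->
    # Each step consumes the previous hidden state and produces the next one
    {out, new_hidden} = Recurrent.forward_with_state(model, params, frame, hidden)
    {[out | outs], new_hidden}
  end)

# Outputs were accumulated in reverse; restore frame order
outputs = Enum.reverse(outputs)
```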
## Usage

```elixir
# Build recurrent backbone
model = Recurrent.build(
  embed_dim: 1024,
  hidden_size: 256,
  num_layers: 2,
  cell_type: :lstm,
  dropout: 0.1
)

# Use as backbone in policy network
input = Axon.input("state_sequence", shape: {nil, nil, 1024})
backbone = Recurrent.build_backbone(input, hidden_size: 256, cell_type: :gru)
policy_head = build_policy_head(backbone)
```
## Summary

### Functions

- Apply gradient truncation to a sequence for truncated BPTT.
- Build a recurrent model for sequence processing.
- Build the recurrent backbone from an existing input layer.
- Build a hybrid recurrent-MLP backbone.
- Build a single recurrent layer (LSTM or GRU).
- Build a stateful recurrent model that explicitly manages hidden state.
- Get supported cell types.
- Create a sequence from individual frames for batch processing.
- Create initial hidden state for a given batch size.
- Calculate the output size of a recurrent backbone.
- Pad or truncate sequence to fixed length.
## Types

```elixir
@type build_opt() ::
        {:cell_type, cell_type()}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:return_sequences, boolean()}
        | {:seq_len, pos_integer()}
        | {:truncate_bptt, pos_integer() | nil}
        | {:window_size, pos_integer()}
```

Options for `build/1`.

```elixir
@type cell_type() :: :lstm | :gru
```
## Functions
```elixir
@spec apply_gradient_truncation(Axon.t(), pos_integer()) :: Axon.t()
```

Apply gradient truncation to a sequence for truncated BPTT.

This creates an Axon layer that stops gradients from flowing back through timesteps earlier than the last `keep_steps` frames.
### How it works

For a sequence of 60 frames with `truncate_bptt: 15`:

- Forward pass: all 60 frames are processed normally
- Backward pass: gradients only flow through the last 15 frames
- Earlier frames have their gradients stopped with `Nx.stop_gradient`
### Performance Impact

- ~2-3x faster training (less gradient computation)
- May reduce the ability to learn very long-range dependencies
- Recommended: start with `window_size / 2` or `window_size / 3`
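The mechanism can be sketched in plain Nx. This is an illustrative stand-in, not the module's actual Axon layer: it splits the time axis, stops gradients on the early slice via `stop_grad/1` (auto-imported inside `defn` from `Nx.Defn.Kernel`), and re-concatenates. The module name is hypothetical:

```elixir
defmodule TruncatedBPTTSketch do
  import Nx.Defn

  # Illustrative only: keep forward values intact, but stop gradients for
  # all but the last `keep_steps` timesteps of a [batch, seq_len, dim] tensor.
  defn truncate(sequence, opts \\ []) do
    opts = keyword!(opts, keep_steps: 15)
    keep = opts[:keep_steps]
    cut = Nx.axis_size(sequence, 1) - keep

    early = Nx.slice_along_axis(sequence, 0, cut, axis: 1)
    late = Nx.slice_along_axis(sequence, cut, keep, axis: 1)

    # Gradients through `early` are zeroed on the backward pass;
    # forward values are unchanged.
    Nx.concatenate([stop_grad(early), late], axis: 1)
  end
end
```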
Build a recurrent model for sequence processing.
### Options

- `:embed_dim` - Size of input embedding per frame (required)
- `:hidden_size` - Size of recurrent hidden state (default: 256)
- `:num_layers` - Number of stacked recurrent layers (default: 1)
- `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
- `:dropout` - Dropout rate between layers (default: 0.0)
- `:bidirectional` - Use bidirectional processing (default: false)
- `:return_sequences` - Return all timesteps or just the last (default: false)
### Returns
An Axon model that processes sequences and outputs hidden representations.
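A minimal end-to-end sketch of building and running the model; the batch and sequence sizes are illustrative, and the options follow the list above:

```elixir
model = Recurrent.build(embed_dim: 1024, hidden_size: 256, cell_type: :gru)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({8, 60, 1024}, :f32), %{})

# 8 sequences of 60 frames each; with return_sequences: false (the default)
# the output is the final hidden representation, shape [8, 256].
hidden = predict_fn.(params, Nx.broadcast(0.0, {8, 60, 1024}))
```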
Build the recurrent backbone from an existing input layer.
Useful for integrating into larger networks (policy, value).
### Options

- `:hidden_size` - Size of recurrent hidden state (default: 256)
- `:num_layers` - Number of stacked recurrent layers (default: 1)
- `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
- `:dropout` - Dropout rate between layers (default: 0.0)
- `:return_sequences` - Return all timesteps or just the last (default: false)
- `:truncate_bptt` - Truncate gradients to the last N steps (default: nil = full BPTT). Set to e.g. 15-20 for 2-3x faster training with some accuracy loss
- `:input_layer_norm` - Apply layer norm to the input for stability (default: true)
- `:use_layer_norm` - Apply layer norm after each RNN layer (default: true)
Build a hybrid recurrent-MLP backbone.
Combines recurrent layers for temporal processing with MLP layers for non-linear transformation. This often works better than a pure RNN backbone.
```
Sequence [batch, seq_len, embed_dim]
        │
        ▼
 ┌─────────────┐
 │  LSTM/GRU   │
 │   Layers    │
 └─────────────┘
        │
        ▼
 [batch, hidden_size]
        │
        ▼
 ┌─────────────┐
 │     MLP     │
 │   Layers    │
 └─────────────┘
        │
        ▼
 [batch, output_size]
```

### Options
- `:embed_dim` - Size of input embedding (required)
- `:recurrent_size` - Size of the recurrent hidden state (default: 256)
- `:mlp_sizes` - List of MLP layer sizes (default: [256])
- `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
- `:num_recurrent_layers` - Number of RNN layers (default: 1)
- `:dropout` - Dropout rate (default: 0.1)
- `:activation` - MLP activation (default: `:relu`)
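A construction sketch using the options above; the function name `build_hybrid/1` is an assumption based on the summary, so check the actual export:

```elixir
# Assumed function name; options taken from the list above.
model =
  Recurrent.build_hybrid(
    embed_dim: 1024,
    recurrent_size: 256,
    mlp_sizes: [256, 128],
    cell_type: :gru,
    dropout: 0.1
  )
```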
```elixir
@spec build_recurrent_layer(Axon.t(), non_neg_integer(), cell_type(), keyword()) :: Axon.t()
```
Build a single recurrent layer (LSTM or GRU).
### Options

- `:name` - Layer name prefix
- `:return_sequences` - Whether to return all timesteps or just the last (default: true)
- `:use_layer_norm` - Add layer normalization after the RNN for stability (default: true)
- `:recurrent_initializer` - Initializer for recurrent weights (default: `:glorot_uniform`)
### Stability Notes

RNNs are prone to gradient explosion and vanishing. This implementation uses:

- Orthogonal initialization for recurrent weights (preserves gradient magnitude)
- Layer normalization after each RNN layer (stabilizes hidden state magnitudes)
- Standard Glorot initialization for input weights (via Axon defaults)

If training still diverges, reduce the learning rate to 1e-5 and clip gradients at a global norm of 0.5.
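One way to wire those two knobs together, sketched with Polaris update combinators (Axon's training loops accept such update functions; verify the exact option names against the Polaris version in use):

```elixir
# Clip the global gradient norm at 0.5, then apply Adam scaled by a
# 1e-5 learning rate (negative because updates are added to params).
optimizer =
  Polaris.Updates.clip_by_global_norm(max_norm: 0.5)
  |> Polaris.Updates.scale_by_adam()
  |> Polaris.Updates.scale(-1.0e-5)

loop = Axon.Loop.trainer(model, :categorical_cross_entropy, optimizer)
```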
Build a stateful recurrent model that explicitly manages hidden state.
This is essential for real-time inference where we process one frame at a time and need to carry hidden state between frames.
Returns a simple model that processes single frames. Hidden state management is handled externally using `initial_hidden/2`.

### Options

- `:embed_dim` - Size of input embedding (required)
- `:hidden_size` - Size of hidden state (default: 256)
- `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)

### Returns

An Axon model that takes single frames and outputs hidden representations.
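A construction sketch; the name `build_stateful/1` is an assumption based on the summary above:

```elixir
# Assumed function name: a single-frame model for real-time inference.
frame_model =
  Recurrent.build_stateful(embed_dim: 1024, hidden_size: 256, cell_type: :lstm)
```

Hidden state is then threaded through `forward_with_state/4` as shown in the Hidden State Management section.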
```elixir
@spec cell_types() :: [cell_type()]
```
Get supported cell types.
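Given `cell_type/0` above, this presumably returns both atoms:

```elixir
Recurrent.cell_types()
#=> [:lstm, :gru]
```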
```elixir
@spec frames_to_sequence([Nx.Tensor.t()]) :: Nx.Tensor.t()
```
Create a sequence from individual frames for batch processing.
Takes a list of embedded frames and stacks them into a sequence tensor.
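The core of this is a stack along a new time axis; a sketch of the equivalent plain-Nx logic, assuming each frame is a `[batch, embed_dim]` tensor:

```elixir
# Three dummy frames, each [batch = 2, embed_dim = 4]
frames = for _ <- 1..3, do: Nx.broadcast(0.0, {2, 4})

# Stack along new axis 1 -> shape {2, 3, 4} = [batch, seq_len, embed_dim]
sequence = Nx.stack(frames, axis: 1)
```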
```elixir
@spec output_size(keyword()) :: non_neg_integer()
```
Calculate the output size of a recurrent backbone.
```elixir
@spec pad_sequence(Nx.Tensor.t(), non_neg_integer(), keyword()) :: Nx.Tensor.t()
```
Pad or truncate sequence to fixed length.
Useful for batch processing sequences of different lengths.
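A sketch of the intended semantics (whether padding uses zeros, and on which side, is an assumption; check the function's options):

```elixir
short = Nx.broadcast(1.0, {1, 45, 1024})

# Pad up to 60 frames -> shape {1, 60, 1024}; a longer input would be
# truncated down to 60 instead.
padded = Recurrent.pad_sequence(short, 60)
```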