Edifice.Recurrent (Edifice v0.2.0)

Recurrent neural network layers for temporal sequence processing.

Provides LSTM and GRU architectures for learning temporal dependencies in sequential data - essential for understanding:

  • Multi-step action sequences
  • Temporal patterns and trends
  • Long-range dependencies
  • Reactive decision sequences

Architecture

The recurrent backbone processes sequences of embedded states:

Frame Sequence [batch, seq_len, embed_dim]
        │
        ▼
  LSTM/GRU Layer 1 ──► hidden state (h, c for LSTM)
        │
        ▼
  LSTM/GRU Layer 2     (optional stacked layers)
        │
        ▼
Hidden Output [batch, hidden_size]

Hidden State Management

For real-time inference, hidden states must be carried between frames:

# Initialize hidden state
hidden = Recurrent.initial_hidden(batch_size, hidden_size: 256, cell_type: :lstm)

# Process frame, get new hidden
{output, new_hidden} = Recurrent.forward_with_state(model, params, frame, hidden)

# Use new_hidden for next frame
...

Usage

# Build recurrent backbone
model = Recurrent.build(
  embed_dim: 1024,
  hidden_size: 256,
  num_layers: 2,
  cell_type: :lstm,
  dropout: 0.1
)

# Use as backbone in policy network
input = Axon.input("state_sequence", shape: {nil, nil, 1024})
backbone = Recurrent.build_backbone(input, hidden_size: 256, cell_type: :gru)
policy_head = build_policy_head(backbone)

Summary

Functions

apply_gradient_truncation(input, keep_steps)
Apply gradient truncation to a sequence for truncated BPTT.

build(opts \\ [])
Build a recurrent model for sequence processing.

build_backbone(input, opts \\ [])
Build the recurrent backbone from an existing input layer.

build_hybrid(opts \\ [])
Build a hybrid recurrent-MLP backbone.

build_recurrent_layer(input, hidden_size, cell_type, opts \\ [])
Build a single recurrent layer (LSTM or GRU).

build_stateful(opts \\ [])
Build a stateful recurrent model that explicitly manages hidden state.

cell_types()
Get supported cell types.

frames_to_sequence(frames)
Create a sequence from individual frames for batch processing.

initial_hidden(batch_size, opts \\ [])
Create initial hidden state for a given batch size.

output_size(opts \\ [])
Calculate the output size of a recurrent backbone.

pad_sequence(sequence, target_length, opts \\ [])
Pad or truncate sequence to fixed length.

Types

build_opt()

@type build_opt() ::
  {:cell_type, cell_type()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:return_sequences, boolean()}
  | {:seq_len, pos_integer()}
  | {:truncate_bptt, pos_integer() | nil}
  | {:window_size, pos_integer()}

Options for build/1.

cell_type()

@type cell_type() :: :lstm | :gru

hidden_state()

@type hidden_state() :: Nx.Tensor.t() | {Nx.Tensor.t(), Nx.Tensor.t()}

Functions

apply_gradient_truncation(input, keep_steps)

@spec apply_gradient_truncation(Axon.t(), pos_integer()) :: Axon.t()

Apply gradient truncation to a sequence for truncated BPTT.

This creates an Axon layer that stops gradients from flowing back through timesteps earlier than the last keep_steps frames.

How it works

For a sequence of 60 frames with truncate_bptt=15:

  • Forward pass: all 60 frames processed normally
  • Backward pass: gradients only flow through the last 15 frames
  • Earlier frames have their gradients stopped with Nx.stop_gradient

Performance Impact

  • ~2-3x faster training (less gradient computation)
  • May reduce ability to learn very long-range dependencies
  • Recommended: start with window_size/2 or window_size/3
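
A minimal sketch of the underlying idea in plain Nx, assuming a [batch, seq_len, embed_dim] sequence; the library wraps the equivalent logic in an Axon layer, so exact details may differ:

# Conceptual sketch (not necessarily the library's exact implementation):
# dummy [batch, seq_len, embed_dim] sequence
sequence = Nx.broadcast(0.0, {8, 60, 1024})
keep_steps = 15
seq_len = Nx.axis_size(sequence, 1)

# Block gradients through the early frames, keep them for the last keep_steps
early =
  sequence
  |> Nx.slice_along_axis(0, seq_len - keep_steps, axis: 1)
  |> Nx.stop_gradient()

late = Nx.slice_along_axis(sequence, seq_len - keep_steps, keep_steps, axis: 1)
truncated = Nx.concatenate([early, late], axis: 1)

# At the Axon graph level, the same effect via this module:
input = Axon.input("state_sequence", shape: {nil, 60, 1024})
truncated_input = Recurrent.apply_gradient_truncation(input, 15)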

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a recurrent model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Size of recurrent hidden state (default: 256)
  • :num_layers - Number of stacked recurrent layers (default: 1)
  • :cell_type - :lstm or :gru (default: :lstm)
  • :dropout - Dropout rate between layers (default: 0.0)
  • :bidirectional - Use bidirectional processing (default: false)
  • :return_sequences - Return all timesteps or just last (default: false)

Returns

An Axon model that processes sequences and outputs hidden representations.
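
For example, a hedged sketch with illustrative sizes, returning every timestep rather than just the last:

model = Recurrent.build(
  embed_dim: 512,
  hidden_size: 128,
  num_layers: 2,
  cell_type: :gru,
  return_sequences: true
)
# With return_sequences: true, the output keeps the time axis
# ([batch, seq_len, hidden_size]); with false it is [batch, hidden_size].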

build_backbone(input, opts \\ [])

@spec build_backbone(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the recurrent backbone from an existing input layer.

Useful for integrating into larger networks (policy, value).

Options

  • :hidden_size - Size of recurrent hidden state (default: 256)
  • :num_layers - Number of stacked recurrent layers (default: 1)
  • :cell_type - :lstm or :gru (default: :lstm)
  • :dropout - Dropout rate between layers (default: 0.0)
  • :return_sequences - Return all timesteps or just last (default: false)
  • :truncate_bptt - Truncate gradients to the last N steps (default: nil = full BPTT). Set to e.g. 15-20 for 2-3x faster training with some accuracy loss
  • :input_layer_norm - Apply layer norm to input for stability (default: true)
  • :use_layer_norm - Apply layer norm after each RNN layer (default: true)
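
A hedged example combining several of these options (values are illustrative, not recommendations):

input = Axon.input("state_sequence", shape: {nil, nil, 1024})

backbone =
  Recurrent.build_backbone(input,
    hidden_size: 256,
    num_layers: 2,
    cell_type: :lstm,
    dropout: 0.1,
    truncate_bptt: 15
  )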

build_hybrid(opts \\ [])

@spec build_hybrid(keyword()) :: Axon.t()

Build a hybrid recurrent-MLP backbone.

Combines recurrent layers for temporal processing with MLP layers for non-linear transformation. This often works better than a pure RNN.

Sequence [batch, seq_len, embed_dim]
        │
        ▼
  LSTM/GRU Layers
        │
        ▼
[batch, hidden_size]
        │
        ▼
  MLP Layers
        │
        ▼
[batch, output_size]

Options

  • :embed_dim - Size of input embedding (required)
  • :recurrent_size - Size of recurrent hidden (default: 256)
  • :mlp_sizes - List of MLP layer sizes (default: [256])
  • :cell_type - :lstm or :gru (default: :lstm)
  • :num_recurrent_layers - Number of RNN layers (default: 1)
  • :dropout - Dropout rate (default: 0.1)
  • :activation - MLP activation (default: :relu)

build_recurrent_layer(input, hidden_size, cell_type, opts \\ [])

@spec build_recurrent_layer(Axon.t(), non_neg_integer(), cell_type(), keyword()) ::
  Axon.t()

Build a single recurrent layer (LSTM or GRU).

Options

  • :name - Layer name prefix
  • :return_sequences - Whether to return all timesteps or just the last (default: true)
  • :use_layer_norm - Add layer normalization after RNN for stability (default: true)
  • :recurrent_initializer - Initializer for recurrent weights (default: :glorot_uniform)

Stability Notes

RNNs are prone to gradient explosion/vanishing. This implementation uses:

  1. Orthogonal initialization for recurrent weights (preserves gradient magnitude)
  2. Layer normalization after each RNN layer (stabilizes hidden state magnitudes)
  3. Standard glorot for input weights (via Axon defaults)

If training still diverges, reduce the learning rate to 1e-5 and clip gradients at 0.5.
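
A hedged example wiring a single layer into an Axon graph (the input name and shape are illustrative):

input = Axon.input("embedded_frames", shape: {nil, nil, 512})

layer =
  Recurrent.build_recurrent_layer(input, 256, :gru,
    name: "gru_1",
    return_sequences: false
  )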

build_stateful(opts \\ [])

@spec build_stateful(keyword()) :: Axon.t()

Build a stateful recurrent model that explicitly manages hidden state.

This is essential for real-time inference where we process one frame at a time and need to carry hidden state between frames.

Returns a simple model that processes single frames. Hidden state management is handled externally using initial_hidden/2.

Options

  • :embed_dim - Size of input embedding (required)
  • :hidden_size - Size of hidden state (default: 256)
  • :cell_type - :lstm or :gru (default: :lstm)

Returns

An Axon model that takes single frames and outputs hidden representations.
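
A hedged sketch for single-frame, real-time use (a batch size of 1 is assumed for live inference):

model = Recurrent.build_stateful(embed_dim: 1024, hidden_size: 256, cell_type: :lstm)

# Hidden state lives outside the model; see Hidden State Management above.
hidden = Recurrent.initial_hidden(1, hidden_size: 256, cell_type: :lstm)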

cell_types()

@spec cell_types() :: [cell_type()]

Get supported cell types.

frames_to_sequence(frames)

@spec frames_to_sequence([Nx.Tensor.t()]) :: Nx.Tensor.t()

Create a sequence from individual frames for batch processing.

Takes a list of embedded frames and stacks them into a sequence tensor.
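
A hedged example, assuming each frame is a [batch, embed_dim] embedding and the result follows the [batch, seq_len, embed_dim] layout used elsewhere in this module:

# Three dummy embedded frames of shape {1, 1024}
frames = for _ <- 1..3, do: Nx.broadcast(0.0, {1, 1024})

sequence = Recurrent.frames_to_sequence(frames)
# Expected shape (assumed): {1, 3, 1024}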

initial_hidden(batch_size, opts \\ [])

@spec initial_hidden(
  non_neg_integer(),
  keyword()
) :: hidden_state()

Create initial hidden state for a given batch size.

Options

  • :hidden_size - Size of hidden state (default: 256)
  • :cell_type - :lstm or :gru (default: :lstm)

Returns

For LSTM: a {h, c} tuple of zero tensors.
For GRU: a single zero tensor.
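
For example (shapes are assumptions based on the defaults above):

# LSTM: {h, c} tuple of zero tensors, presumably {8, 256} each
{h, c} = Recurrent.initial_hidden(8, hidden_size: 256, cell_type: :lstm)

# GRU: single zero tensor
h = Recurrent.initial_hidden(8, hidden_size: 256, cell_type: :gru)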

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Calculate the output size of a recurrent backbone.

pad_sequence(sequence, target_length, opts \\ [])

@spec pad_sequence(Nx.Tensor.t(), non_neg_integer(), keyword()) :: Nx.Tensor.t()

Pad or truncate sequence to fixed length.

Useful for batch processing sequences of different lengths.
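
A hedged example, assuming the [batch, seq_len, embed_dim] layout used elsewhere in this module; whether frames are padded at the start or the end (and with what value) is up to the implementation, so treat this as a sketch:

# 42 frames, target length 60
sequence = Nx.broadcast(0.0, {1, 42, 1024})
fixed = Recurrent.pad_sequence(sequence, 60)
# Nx.shape(fixed) should then be {1, 60, 1024}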