# `Edifice.Recurrent`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/recurrent/recurrent.ex#L1)

Recurrent neural network layers for temporal sequence processing.

Provides LSTM and GRU architectures for learning temporal dependencies
in sequential data. Typical uses include capturing:
- Multi-step action sequences
- Temporal patterns and trends
- Long-range dependencies
- Reactive decision sequences

## Architecture

The recurrent backbone processes sequences of embedded states:

```
Frame Sequence [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────┐
│  LSTM/GRU   │ ←─ hidden state (h, c for LSTM)
│  Layer 1    │
└─────────────┘
      │
      ▼
┌─────────────┐
│  LSTM/GRU   │  (optional stacked layers)
│  Layer 2    │
└─────────────┘
      │
      ▼
Hidden Output [batch, hidden_size]
(or [batch, seq_len, hidden_size] with return_sequences: true)
```

## Hidden State Management

For real-time inference, hidden states must be carried between frames:

    # Initialize hidden state (zero tensors) for the batch
    hidden = Recurrent.initial_hidden(batch_size, hidden_size: 256, cell_type: :lstm)

    # Process frame, get new hidden
    {output, new_hidden} = Recurrent.forward_with_state(model, params, frame, hidden)

    # Use new_hidden for next frame
    ...
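
The carry-the-state pattern can be sketched in plain Nx. This is a minimal illustration, not the module's code: the toy `step/2` (`h_t = tanh(x_t + h_{t-1})`) is a hypothetical stand-in for a real LSTM/GRU cell, frames are assumed to be `[batch, hidden_size]` tensors, and `Enum.map_reduce/3` threads the hidden state from one frame to the next.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule StreamingSketch do
  # Toy recurrent step (hypothetical stand-in for an LSTM/GRU cell):
  # h_t = tanh(x_t + h_{t-1}); returns {output, new_hidden}.
  def step(frame, hidden) do
    new_hidden = Nx.tanh(Nx.add(frame, hidden))
    {new_hidden, new_hidden}
  end

  # Feed frames one at a time, carrying the hidden state between calls.
  # Returns {outputs_per_frame, final_hidden}.
  def run(frames, hidden), do: Enum.map_reduce(frames, hidden, &step/2)
end

hidden = Nx.broadcast(0.0, {1, 4})
frames = for _ <- 1..3, do: Nx.broadcast(0.5, {1, 4})
{outputs, final_hidden} = StreamingSketch.run(frames, hidden)
```

Each step returns both the per-frame output and the state to pass into the next step, so a real-time loop only has to keep `final_hidden` around between frames.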

## Usage

    alias Edifice.Recurrent

    # Build a standalone recurrent model
    model = Recurrent.build(
      embed_dim: 1024,
      hidden_size: 256,
      num_layers: 2,
      cell_type: :lstm,
      dropout: 0.1
    )

    # Use as a backbone inside a policy network
    input = Axon.input("state_sequence", shape: {nil, nil, 1024})
    backbone = Recurrent.build_backbone(input, hidden_size: 256, cell_type: :gru)
    # build_policy_head/1 is your own function that adds the output layers
    policy_head = build_policy_head(backbone)

# `build_opt`

```elixir
@type build_opt() ::
  {:cell_type, cell_type()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:return_sequences, boolean()}
  | {:seq_len, pos_integer()}
  | {:truncate_bptt, pos_integer() | nil}
  | {:window_size, pos_integer()}
```

Options for `build/1`.

# `cell_type`

```elixir
@type cell_type() :: :lstm | :gru
```

# `hidden_state`

```elixir
@type hidden_state() :: Nx.Tensor.t() | {Nx.Tensor.t(), Nx.Tensor.t()}
```

# `apply_gradient_truncation`

```elixir
@spec apply_gradient_truncation(Axon.t(), pos_integer()) :: Axon.t()
```

Apply gradient truncation to a sequence for truncated backpropagation through time (BPTT).

This creates an Axon layer that stops gradients from flowing back through
timesteps earlier than the last `keep_steps` frames.

## How it works
For a sequence of 60 frames with `truncate_bptt: 15`:
- Forward pass: all 60 frames are processed normally
- Backward pass: gradients flow only through the last 15 frames
- Earlier frames have their gradients stopped with `Nx.stop_gradient`

## Performance Impact
- Roughly 2-3x faster training, since less gradient computation is needed
- May reduce the ability to learn very long-range dependencies
- Recommended starting point: `window_size / 2` or `window_size / 3`
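
A minimal sketch of the gradient-stopping idea in `Nx.Defn`, assuming a `[batch, seq_len, embed_dim]` input. It is not necessarily how `apply_gradient_truncation/2` is implemented: instead of slicing and re-concatenating, it builds a mask over the last `keep_steps` timesteps and uses `Nx.select/3` to choose between the tensor and a `stop_grad/1`-wrapped copy. The forward value is unchanged; only the masked timesteps pass gradients back. The module and function names are hypothetical.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TruncatedBPTTSketch do
  import Nx.Defn

  # Mask with the sequence's shape: 1 for the last `keep_steps` timesteps
  # along axis 1 ([batch, seq_len, embed_dim]), 0 everywhere else.
  def keep_mask(seq, keep_steps) do
    seq_len = Nx.axis_size(seq, 1)

    seq
    |> Nx.shape()
    |> Nx.iota(axis: 1)
    |> Nx.greater_equal(seq_len - keep_steps)
  end

  # Forward value equals `seq`; gradients only flow where mask is 1,
  # because the other branch is wrapped in stop_grad/1.
  defn truncate(seq, mask) do
    Nx.select(mask, seq, stop_grad(seq))
  end

  # Gradient of sum(truncate(seq)) w.r.t. seq, to inspect where gradients survive.
  defn grad_of_sum(seq, mask) do
    grad(seq, fn s -> s |> truncate(mask) |> Nx.sum() end)
  end
end

seq = Nx.iota({2, 6, 3}, type: :f32)
mask = TruncatedBPTTSketch.keep_mask(seq, 2)
grads = TruncatedBPTTSketch.grad_of_sum(seq, mask)
```

The gradient tensor is zero for the first four timesteps and one for the last two, while `truncate/2` returns `seq` unchanged in the forward pass.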

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a recurrent model for sequence processing.

## Options
  - `:embed_dim` - Size of input embedding per frame (required)
  - `:hidden_size` - Size of recurrent hidden state (default: 256)
  - `:num_layers` - Number of stacked recurrent layers (default: 1)
  - `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
  - `:dropout` - Dropout rate between layers (default: 0.0)
  - `:bidirectional` - Use bidirectional processing (default: false)
  - `:return_sequences` - Return all timesteps or just last (default: false)

## Returns
  An Axon model that processes sequences and outputs hidden representations.
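
To make `:return_sequences` concrete, here is a hypothetical helper (not part of the module) showing the assumed shape change: the full `[batch, seq_len, hidden_size]` output is kept when the option is `true`, and only the final timestep, `[batch, hidden_size]`, is kept otherwise.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule ReturnSeqSketch do
  # true  -> keep all timesteps: [batch, seq_len, hidden_size]
  # false -> keep only the last timestep: [batch, hidden_size]
  def select_output(outputs, return_sequences) do
    if return_sequences do
      outputs
    else
      last = Nx.axis_size(outputs, 1) - 1

      outputs
      |> Nx.slice_along_axis(last, 1, axis: 1)
      |> Nx.squeeze(axes: [1])
    end
  end
end

outputs = Nx.iota({2, 4, 3}, type: :f32)
last_step = ReturnSeqSketch.select_output(outputs, false)
```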

# `build_backbone`

```elixir
@spec build_backbone(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Build the recurrent backbone from an existing input layer.

Useful for integrating into larger networks, such as policy or value networks.

## Options
  - `:hidden_size` - Size of recurrent hidden state (default: 256)
  - `:num_layers` - Number of stacked recurrent layers (default: 1)
  - `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
  - `:dropout` - Dropout rate between layers (default: 0.0)
  - `:return_sequences` - Return all timesteps or just last (default: false)
  - `:truncate_bptt` - Truncate gradients to the last N steps (default: `nil`, i.e. full BPTT).
    Values around 15-20 give roughly 2-3x faster training at some cost in accuracy
  - `:input_layer_norm` - Apply layer norm to input for stability (default: true)
  - `:use_layer_norm` - Apply layer norm after each RNN layer (default: true)

# `build_hybrid`

```elixir
@spec build_hybrid(keyword()) :: Axon.t()
```

Build a hybrid recurrent-MLP backbone.

Combines recurrent layers for temporal processing with MLP layers
for non-linear transformation. This often works better than a pure RNN stack.

```
Sequence [batch, seq_len, embed_dim]
      │
      ▼
┌─────────────┐
│  LSTM/GRU   │
│  Layers     │
└─────────────┘
      │
      ▼
[batch, hidden_size]
      │
      ▼
┌─────────────┐
│    MLP      │
│  Layers     │
└─────────────┘
      │
      ▼
[batch, output_size]
```

## Options
  - `:embed_dim` - Size of input embedding (required)
  - `:recurrent_size` - Size of recurrent hidden (default: 256)
  - `:mlp_sizes` - List of MLP layer sizes (default: [256])
  - `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)
  - `:num_recurrent_layers` - Number of RNN layers (default: 1)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:activation` - MLP activation (default: :relu)

# `build_recurrent_layer`

```elixir
@spec build_recurrent_layer(Axon.t(), non_neg_integer(), cell_type(), keyword()) ::
  Axon.t()
```

Build a single recurrent layer (LSTM or GRU).

## Options
  - `:name` - Layer name prefix
  - `:return_sequences` - Whether to return all timesteps or just the last (default: true)
  - `:use_layer_norm` - Add layer normalization after RNN for stability (default: true)
  - `:recurrent_initializer` - Initializer for recurrent weights (default: :glorot_uniform)

## Stability Notes

RNNs are prone to exploding and vanishing gradients. This implementation uses:
1. **Orthogonal initialization** for recurrent weights (preserves gradient magnitude)
2. **Layer normalization** after each RNN layer (stabilizes hidden state magnitudes)
3. Standard Glorot initialization for input weights (Axon's default)

If training still diverges, reduce the learning rate to 1e-5 and clip gradients at 0.5.
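
The text does not say which clipping variant the 0.5 threshold refers to. Assuming clipping by global L2 norm (a common choice for RNNs), a minimal sketch over a list of gradient tensors looks like this; `ClipSketch` is a hypothetical name, not part of the module.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule ClipSketch do
  # Assumed variant: clip by global norm. Rescales every gradient by the
  # same factor so their combined L2 norm is at most max_norm.
  def clip_by_global_norm(grads, max_norm \\ 0.5) do
    global_norm =
      grads
      |> Enum.map(fn g -> g |> Nx.pow(2) |> Nx.sum() |> Nx.to_number() end)
      |> Enum.sum()
      |> :math.sqrt()

    # Guard against division by zero when all gradients are zero
    scale = min(1.0, max_norm / max(global_norm, 1.0e-12))
    Enum.map(grads, &Nx.multiply(&1, scale))
  end
end

clipped = ClipSketch.clip_by_global_norm([Nx.tensor([3.0, 0.0]), Nx.tensor([4.0])])
```

Here the global norm is 5.0, so both tensors are scaled by 0.1; gradients whose norm is already below 0.5 pass through unchanged.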

# `build_stateful`

```elixir
@spec build_stateful(keyword()) :: Axon.t()
```

Build a stateful recurrent model that explicitly manages hidden state.

This is essential for real-time inference where we process one frame at a time
and need to carry hidden state between frames.

Returns a simple model that processes single frames. Hidden state management
is handled externally using `initial_hidden/2`.

## Options
  - `:embed_dim` - Size of input embedding (required)
  - `:hidden_size` - Size of hidden state (default: 256)
  - `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)

## Returns
  An Axon model that takes single frames and outputs hidden representations.

# `cell_types`

```elixir
@spec cell_types() :: [cell_type()]
```

Returns the list of supported cell types (`:lstm` and `:gru`).

# `frames_to_sequence`

```elixir
@spec frames_to_sequence([Nx.Tensor.t()]) :: Nx.Tensor.t()
```

Create a sequence from individual frames for batch processing.

Takes a list of embedded frames and stacks them into a sequence tensor.
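
A minimal sketch of the stacking step, assuming each frame is a `[batch, embed_dim]` tensor so that stacking along axis 1 produces the `[batch, seq_len, embed_dim]` layout shown in the architecture diagram. `SeqSketch` is a hypothetical name; the module's own function may assume a different frame shape.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule SeqSketch do
  # Stack a list of [batch, embed_dim] frames (oldest first) into a
  # [batch, seq_len, embed_dim] tensor by inserting a new time axis at 1.
  def frames_to_sequence(frames) when is_list(frames) and frames != [] do
    Nx.stack(frames, axis: 1)
  end
end

frames = for t <- 0..2, do: Nx.broadcast(t * 1.0, {2, 4})
sequence = SeqSketch.frames_to_sequence(frames)
```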

# `initial_hidden`

```elixir
@spec initial_hidden(
  non_neg_integer(),
  keyword()
) :: hidden_state()
```

Create initial hidden state for a given batch size.

## Options
  - `:hidden_size` - Size of hidden state (default: 256)
  - `:cell_type` - `:lstm` or `:gru` (default: `:lstm`)

## Returns
  - LSTM: a `{h, c}` tuple of zero tensors
  - GRU: a single zero tensor
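
A sketch of the documented return values, assuming each state tensor has shape `{batch_size, hidden_size}` (the shape is not stated above). `HiddenSketch.zero_hidden/2` is a hypothetical stand-in, not the module's function.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule HiddenSketch do
  # Hypothetical mirror of initial_hidden/2's documented results,
  # assuming zero tensors of shape {batch_size, hidden_size}.
  def zero_hidden(batch_size, opts \\ []) do
    hidden_size = Keyword.get(opts, :hidden_size, 256)
    zeros = fn -> Nx.broadcast(0.0, {batch_size, hidden_size}) end

    case Keyword.get(opts, :cell_type, :lstm) do
      # LSTM carries two states: hidden h and cell c
      :lstm -> {zeros.(), zeros.()}
      # GRU carries a single hidden tensor
      :gru -> zeros.()
    end
  end
end

{h, c} = HiddenSketch.zero_hidden(4, hidden_size: 8)
```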

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Calculate the output size of a recurrent backbone.

# `pad_sequence`

```elixir
@spec pad_sequence(Nx.Tensor.t(), non_neg_integer(), keyword()) :: Nx.Tensor.t()
```

Pad or truncate sequence to fixed length.

Useful for batch processing sequences of different lengths.
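
A hypothetical sketch of pad-or-truncate for a single sequence. It assumes axis 0 is time (`[seq_len, embed_dim]`), that padding uses zeros at the front, and that truncation keeps the most recent frames; the real `pad_sequence/3` may choose differently through its options.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule PadSketch do
  # Assumptions: axis 0 is time, pad at the front, keep the newest frames.
  def pad_or_truncate(seq, target_len, pad_value \\ 0.0) do
    seq_len = Nx.axis_size(seq, 0)

    cond do
      seq_len > target_len ->
        # Drop the oldest frames, keep the last target_len
        Nx.slice_along_axis(seq, seq_len - target_len, target_len, axis: 0)

      seq_len < target_len ->
        # Insert (target_len - seq_len) rows of pad_value before the first frame
        rest = List.duplicate({0, 0, 0}, Nx.rank(seq) - 1)
        Nx.pad(seq, pad_value, [{target_len - seq_len, 0, 0} | rest])

      true ->
        seq
    end
  end
end

seq = Nx.iota({3, 2}, type: :f32)
padded = PadSketch.pad_or_truncate(seq, 5)
```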

