# `Edifice.Contrastive.TemporalJEPA`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/contrastive/temporal_jepa.ex#L1)

Temporal JEPA — Joint Embedding Predictive Architecture for sequences.

Extends JEPA to temporal data (video frames, time series, trajectories).
The context encoder processes visible timesteps with bidirectional attention,
and the predictor estimates representations of masked timesteps.

## Architecture

```
Visible timesteps [batch, seq_len, input_dim]
        |
+========================+
| Context Encoder        |
|  Input projection      |
|  + Positional embed    |
|  Bidirectional Attn×N  |
|  LayerNorm             |
|  Mean Pool             |
+========================+
        |
[batch, embed_dim]  (context representation)

Context Repr [batch, embed_dim]
        |
+========================+
| Predictor              |
|  Project to pred_dim   |
|  MLP blocks × M        |
|  LayerNorm             |
|  Project to embed_dim  |
+========================+
        |
[batch, embed_dim]  (predicted target representation)
```

The target encoder is architecturally identical to the context encoder;
its parameters are an EMA copy of the context encoder's and receive no
gradients (it is kept out of the computational graph).

## Returns

`{context_encoder, predictor}` — two Axon models.

## Usage

    {context_encoder, predictor} = TemporalJEPA.build(
      input_dim: 128,
      embed_dim: 128,
      predictor_embed_dim: 64,
      seq_len: 60,
      mask_ratio: 0.5
    )

    # Training: encode visible frames, predict masked frame representations
    # Target encoder uses EMA of context encoder weights
    target_params = TemporalJEPA.ema_update(context_params, target_params, momentum: 0.996)

## References

- Bardes et al., "Revisiting Feature Prediction for Learning Visual
  Representations from Video" (V-JEPA, Meta AI, 2024)
- Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding
  Predictive Architecture" (CVPR 2023)

# `build_opt`

```elixir
@type build_opt() ::
  {:input_dim, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:predictor_embed_dim, pos_integer()}
  | {:encoder_depth, pos_integer()}
  | {:predictor_depth, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:dropout, float()}
  | {:seq_len, pos_integer()}
  | {:mask_ratio, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t()}
```

Build both the context encoder and predictor networks.

## Options

  - `:input_dim` - Input feature dimension per timestep (required)
  - `:embed_dim` - Encoder embedding dimension (default: 128)
  - `:predictor_embed_dim` - Predictor internal dimension (default: 64)
  - `:encoder_depth` - Number of transformer blocks in encoder (default: 4)
  - `:predictor_depth` - Number of MLP blocks in predictor (default: 2)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:dropout` - Dropout rate (default: 0.1)
  - `:seq_len` - Expected sequence length (default: 60)
  - `:mask_ratio` - Fraction of timesteps to mask (default: 0.5)

## Returns

  `{context_encoder, predictor}` tuple of Axon models.
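Assuming the `:axon` and `:nx` dependencies are available, a forward pass
through the returned pair can be sketched as below. Shapes follow the docs
above; the batch size of 8 and zero-filled input are illustrative only.

```elixir
alias Edifice.Contrastive.TemporalJEPA

{context_encoder, predictor} = TemporalJEPA.build(input_dim: 128, embed_dim: 128)

# Compile each Axon model into an {init_fn, predict_fn} pair.
{enc_init, enc_pred} = Axon.build(context_encoder)
{pre_init, pre_pred} = Axon.build(predictor)

frames = Nx.broadcast(0.0, {8, 60, 128})                     # [batch, seq_len, input_dim]
enc_params = enc_init.(Nx.template({8, 60, 128}, :f32), %{})
context = enc_pred.(enc_params, frames)                      # [batch, embed_dim]

pre_params = pre_init.(Nx.template({8, 128}, :f32), %{})
predicted = pre_pred.(pre_params, context)                   # [batch, embed_dim]
```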

# `build_context_encoder`

```elixir
@spec build_context_encoder(keyword()) :: Axon.t()
```

Build the context encoder for temporal sequences.

Processes visible timesteps through bidirectional self-attention
(no causal mask) and mean-pools to a fixed-size representation.

## Returns

  Axon model: `[batch, seq_len, input_dim]` → `[batch, embed_dim]`

# `build_predictor`

```elixir
@spec build_predictor(keyword()) :: Axon.t()
```

Build the predictor network.

Takes the context encoder output (flat vector) and predicts the target
encoder's representation of masked timesteps.

## Returns

  Axon model: `[batch, embed_dim]` → `[batch, embed_dim]`

# `default_momentum`

```elixir
@spec default_momentum() :: float()
```

Default EMA momentum (`0.996`).

# `ema_update`

```elixir
@spec ema_update(map(), map(), keyword()) :: map()
```

Update target encoder parameters via exponential moving average.

`target = momentum * target + (1 - momentum) * context`

## Parameters

  - `context_params` - Current context encoder parameters
  - `target_params` - Current target encoder parameters

## Options

  - `:momentum` - EMA coefficient (default: 0.996)

## Returns

  Updated target parameters.
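The EMA rule above can be sketched in plain Elixir over nested parameter
maps; the scalar numbers stand in for Nx tensors, and `EmaSketch` is an
illustrative name, not part of the library.

```elixir
defmodule EmaSketch do
  @momentum 0.996

  # Recursively blends the target toward the context:
  #   target' = momentum * target + (1 - momentum) * context
  def ema_update(context, target, momentum \\ @momentum)

  def ema_update(%{} = context, %{} = target, m) do
    Map.new(target, fn {k, t} -> {k, ema_update(Map.fetch!(context, k), t, m)} end)
  end

  def ema_update(c, t, m) when is_number(c) and is_number(t) do
    m * t + (1 - m) * c
  end
end
```

With the default momentum, a target weight of `0.0` paired with a context
weight of `1.0` moves to `0.004` after one update.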

# `generate_temporal_mask`

```elixir
@spec generate_temporal_mask(Nx.Tensor.t(), pos_integer(), float()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}
```

Generate a random temporal mask selecting which timesteps stay visible.

Returns a boolean tensor of shape `[seq_len]` where `true` means visible
and `false` means masked.

## Parameters

  - `key` - PRNG key
  - `seq_len` - Number of timesteps
  - `mask_ratio` - Fraction to mask (default: 0.5)

## Returns

  `{visible_mask, key}` where `visible_mask` is `[seq_len]` boolean tensor.
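For intuition, the masking scheme can be sketched without Nx: pick
`round(seq_len * mask_ratio)` timesteps uniformly at random to hide.
`MaskSketch` and its list-of-booleans return are hypothetical stand-ins
for the tensor-based implementation.

```elixir
defmodule MaskSketch do
  # true = visible, false = masked (same convention as the docs above).
  def visible_mask(seq_len, mask_ratio \\ 0.5) do
    n_masked = round(seq_len * mask_ratio)
    masked = MapSet.new(Enum.take_random(0..(seq_len - 1), n_masked))
    Enum.map(0..(seq_len - 1), fn t -> not MapSet.member?(masked, t) end)
  end
end
```

With `seq_len: 60` and `mask_ratio: 0.5`, exactly 30 timesteps are masked
and 30 remain visible, though which ones vary per draw.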

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute temporal JEPA loss (smooth L1 between predicted and target representations).

## Parameters

  - `predicted` - Predictor output: `[batch, embed_dim]`
  - `target` - Target encoder output: `[batch, embed_dim]`

## Returns

  Scalar loss tensor.
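As a scalar illustration of smooth L1 (Huber with a cutoff of 1.0,
mean-reduced over elements): the real `loss/2` operates on
`[batch, embed_dim]` Nx tensors, and `LossSketch` is a hypothetical name.

```elixir
defmodule LossSketch do
  # Per element: 0.5 * d^2 when |d| < 1, otherwise |d| - 0.5; then the mean.
  def smooth_l1(predicted, target) do
    diffs =
      Enum.zip_with(predicted, target, fn p, t ->
        d = abs(p - t)
        if d < 1.0, do: 0.5 * d * d, else: d - 0.5
      end)

    Enum.sum(diffs) / length(diffs)
  end
end
```

The quadratic region keeps gradients small near zero error, while the
linear region bounds the penalty on large representation mismatches.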

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size of the temporal JEPA model (the configured `:embed_dim`).

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
