Edifice.Contrastive.TemporalJEPA (Edifice v0.2.0)

Temporal JEPA — Joint Embedding Predictive Architecture for sequences.

Extends JEPA to temporal data (video frames, time series, trajectories). The context encoder processes visible timesteps with bidirectional attention, and the predictor estimates representations of masked timesteps.

Architecture

Visible timesteps [batch, seq_len, input_dim]
        |
+========================+
| Context Encoder        |
|  Input projection      |
|  + Positional embed    |
|  Bidirectional Attn×N  |
|  LayerNorm             |
|  Mean Pool             |
+========================+
        |
[batch, embed_dim]  (context representation)

Context Repr [batch, embed_dim]
        |
+========================+
| Predictor              |
|  Project to pred_dim   |
|  MLP blocks × M        |
|  LayerNorm             |
|  Project to embed_dim  |
+========================+
        |
[batch, embed_dim]  (predicted target representation)

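The target encoder has the same architecture as the context encoder. Its parameters are not trained by gradient descent; they are an exponential moving average (EMA) of the context encoder's parameters (see ema_update/3), so the target encoder is not part of the computational graph.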

Returns

{context_encoder, predictor} — two Axon models.

Usage

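alias Edifice.Contrastive.TemporalJEPA

{context_encoder, predictor} = TemporalJEPA.build(
  input_dim: 128,
  embed_dim: 128,
  predictor_embed_dim: 64,
  seq_len: 60,
  mask_ratio: 0.5
)

# Training: encode the visible frames with the context encoder, then train the
# predictor to match the target encoder's representation of the masked frames.
# context_params are the context encoder's parameters; target_params typically
# start as a copy of them and are updated only by EMA, never by gradients.
target_params = TemporalJEPA.ema_update(context_params, target_params, momentum: 0.996)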

References

  • Bardes et al., "V-JEPA: Latent Video Prediction for Visual Representation Learning" (Meta AI, 2024)
  • Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" (CVPR 2023)

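Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build both the context encoder and predictor networks.
  • build_context_encoder(opts \\ []) - Build the context encoder for temporal sequences.
  • build_predictor(opts \\ []) - Build the predictor network.
  • default_momentum() - Default EMA momentum.
  • ema_update(context_params, target_params, opts \\ []) - Update target encoder parameters via exponential moving average.
  • generate_temporal_mask(key, seq_len, mask_ratio \\ 0.5) - Generate a temporal mask for masking timesteps.
  • loss(predicted, target) - Compute temporal JEPA loss (smooth L1 between predicted and target representations).
  • output_size(opts \\ []) - Get the output size of the temporal JEPA model.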

Types

build_opt()

@type build_opt() ::
  {:input_dim, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:predictor_embed_dim, pos_integer()}
  | {:encoder_depth, pos_integer()}
  | {:predictor_depth, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:dropout, float()}
  | {:seq_len, pos_integer()}
  | {:mask_ratio, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build both the context encoder and predictor networks.

Options

  • :input_dim - Input feature dimension per timestep (required)
  • :embed_dim - Encoder embedding dimension (default: 128)
  • :predictor_embed_dim - Predictor internal dimension (default: 64)
  • :encoder_depth - Number of transformer blocks in encoder (default: 4)
  • :predictor_depth - Number of MLP blocks in predictor (default: 2)
  • :num_heads - Number of attention heads (default: 8)
  • :dropout - Dropout rate (default: 0.1)
  • :seq_len - Expected sequence length (default: 60)
  • :mask_ratio - Fraction of timesteps to mask (default: 0.5)

Returns

{context_encoder, predictor} tuple of Axon models.

build_context_encoder(opts \\ [])

@spec build_context_encoder(keyword()) :: Axon.t()

Build the context encoder for temporal sequences.

Processes visible timesteps through bidirectional self-attention (no causal mask, so every timestep can attend to every other timestep) and mean-pools over the time axis to a fixed-size representation.

Returns

Axon model: [batch, seq_len, input_dim] → [batch, embed_dim]
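To make "no causal mask" and mean pooling concrete, here is a dependency-free sketch on plain Elixir lists, not the library's Axon/Nx code. The module `AttnMaskSketch` and its functions are hypothetical; in each mask, a 1 at row i, column j means timestep i may attend to timestep j.

```elixir
defmodule AttnMaskSketch do
  # Hypothetical helper, not part of Edifice. seq_len is assumed to be >= 1.

  # Bidirectional mask (what the context encoder uses): all ones,
  # so every timestep attends to every other timestep.
  def bidirectional(seq_len) do
    for _i <- 0..(seq_len - 1), do: List.duplicate(1, seq_len)
  end

  # Causal mask, shown only for contrast: timestep i sees keys j <= i.
  def causal(seq_len) do
    for i <- 0..(seq_len - 1) do
      for j <- 0..(seq_len - 1), do: if(j <= i, do: 1, else: 0)
    end
  end

  # Mean pool over time for one sequence: [seq_len][embed_dim] -> [embed_dim].
  def mean_pool(frames) do
    n = length(frames)
    # zip_with/2 hands the callback one column (one embedding dimension) at a time.
    Enum.zip_with(frames, fn column -> Enum.sum(column) / n end)
  end
end

IO.inspect(AttnMaskSketch.bidirectional(3))
# [[1, 1, 1], [1, 1, 1], [1, 1, 1]]
IO.inspect(AttnMaskSketch.causal(3))
# [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
IO.inspect(AttnMaskSketch.mean_pool([[1.0, 2.0], [3.0, 4.0]]))
# [2.0, 3.0]
```

Because the mask is all ones, a visible frame late in the clip can inform the representation of an early one, which suits a representation-learning encoder that never has to generate frames left to right.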

build_predictor(opts \\ [])

@spec build_predictor(keyword()) :: Axon.t()

Build the predictor network.

Takes the context encoder's pooled output ([batch, embed_dim]) and predicts the target encoder's pooled representation of the masked timesteps.

Returns

Axon model: [batch, embed_dim] → [batch, embed_dim]

default_momentum()

@spec default_momentum() :: float()

Default EMA momentum (0.996), the value ema_update/3 uses when :momentum is not given.

ema_update(context_params, target_params, opts \\ [])

@spec ema_update(map(), map(), keyword()) :: map()

Update target encoder parameters via exponential moving average.

target = momentum * target + (1 - momentum) * context

Parameters

  • context_params - Current context encoder parameters
  • target_params - Current target encoder parameters

Options

  • :momentum - EMA coefficient (default: 0.996)

Returns

Updated target parameters.
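The update rule above is applied element-wise to every parameter. The following dependency-free sketch shows the arithmetic; it is hypothetical (the module `EmaSketch` is not part of Edifice) and assumes parameters are a flat map of name => list of floats, whereas the real function works on Axon parameter maps holding Nx tensors.

```elixir
defmodule EmaSketch do
  # Hypothetical, not the Edifice implementation.
  # target = momentum * target + (1 - momentum) * context, element-wise.
  def update(context, target, momentum \\ 0.996) do
    Map.new(target, fn {name, t_values} ->
      c_values = Map.fetch!(context, name)

      updated =
        Enum.zip_with(t_values, c_values, fn t, c ->
          momentum * t + (1 - momentum) * c
        end)

      {name, updated}
    end)
  end
end

context = %{"dense.kernel" => [1.0, 1.0]}
target = %{"dense.kernel" => [0.0, 0.0]}

# With momentum 0.5, one step moves the target halfway toward the context.
IO.inspect(EmaSketch.update(context, target, 0.5))
# %{"dense.kernel" => [0.5, 0.5]}
```

With the default momentum of 0.996, each call moves the target only 0.4% of the way toward the context parameters, so the target behaves roughly like an average over the last 1 / (1 - 0.996) = 250 updates. This slow drift is what keeps the prediction targets stable while the context encoder trains.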

generate_temporal_mask(key, seq_len, mask_ratio \\ 0.5)

@spec generate_temporal_mask(Nx.Tensor.t(), pos_integer(), float()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Generate a temporal mask for masking timesteps.

The mask is a boolean tensor of shape [seq_len] where true means visible and false means masked; it is returned together with a PRNG key as {visible_mask, key}.

Parameters

  • key - PRNG key
  • seq_len - Number of timesteps
  • mask_ratio - Fraction to mask (default: 0.5)

Returns

{visible_mask, key} where visible_mask is [seq_len] boolean tensor.
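A minimal sketch of the masking idea follows, assuming the mask hides exactly round(seq_len * mask_ratio) timesteps chosen uniformly at random; this page does not say whether the library samples individual timesteps or contiguous blocks. The sketch threads an explicit Erlang `:rand` state in place of an Nx PRNG key, mirroring the {visible_mask, key} return shape. The module `TemporalMaskSketch` is hypothetical.

```elixir
defmodule TemporalMaskSketch do
  # Hypothetical, not the Edifice implementation.
  # Returns {visible_mask, rand_state}: true = visible, false = masked.
  def generate(rand_state, seq_len, mask_ratio \\ 0.5) do
    num_masked = round(seq_len * mask_ratio)
    {shuffled, rand_state} = shuffle(Enum.to_list(0..(seq_len - 1)), rand_state)
    masked = MapSet.new(Enum.take(shuffled, num_masked))
    mask = for t <- 0..(seq_len - 1), do: not MapSet.member?(masked, t)
    {mask, rand_state}
  end

  # Keep only the visible timesteps of a sequence (one entry per timestep).
  def visible(frames, mask) do
    frames
    |> Enum.zip(mask)
    |> Enum.filter(fn {_frame, keep?} -> keep? end)
    |> Enum.map(fn {frame, _keep?} -> frame end)
  end

  # Shuffle by sorting on random keys, threading the explicit :rand state
  # so the same seed always yields the same mask.
  defp shuffle(list, state) do
    {keyed, state} =
      Enum.map_reduce(list, state, fn x, s ->
        {r, s} = :rand.uniform_s(s)
        {{r, x}, s}
      end)

    {keyed |> Enum.sort() |> Enum.map(fn {_r, x} -> x end), state}
  end
end

state = :rand.seed_s(:exsss, {1, 2, 3})
{mask, _state} = TemporalMaskSketch.generate(state, 60, 0.5)
IO.inspect(Enum.count(mask, &(!&1)))
# 30
```

With the defaults (seq_len: 60, mask_ratio: 0.5), half of the timesteps are hidden from the context encoder, and the predictor must infer the target encoder's representation of them from the other half.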

loss(predicted, target)

@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute temporal JEPA loss (smooth L1 between predicted and target representations).

Parameters

  • predicted - Predictor output: [batch, embed_dim]
  • target - Target encoder output: [batch, embed_dim]

Returns

Scalar loss tensor.
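For reference, smooth L1 loss on an absolute difference d = |predicted - target| is 0.5 * d^2 / beta when d < beta and d - 0.5 * beta otherwise: quadratic for small errors, linear for large ones. The sketch below assumes beta = 1.0 and a mean over all batch and embedding elements; this page does not state the library's threshold or reduction. `SmoothL1Sketch` is hypothetical and works on nested lists rather than Nx tensors.

```elixir
defmodule SmoothL1Sketch do
  # Hypothetical: beta = 1.0 and mean reduction are assumptions.
  # predicted and target are [batch][embed_dim] nested lists.
  def loss(predicted, target, beta \\ 1.0) do
    diffs =
      Enum.zip_with(predicted, target, fn p_row, t_row ->
        Enum.zip_with(p_row, t_row, fn p, t -> abs(p - t) end)
      end)
      |> List.flatten()

    per_element =
      Enum.map(diffs, fn d ->
        # Quadratic near zero, linear beyond beta (less sensitive to outliers).
        if d < beta, do: 0.5 * d * d / beta, else: d - 0.5 * beta
      end)

    Enum.sum(per_element) / length(per_element)
  end
end

# batch = 1, embed_dim = 2: d = 0.5 (quadratic branch), d = 3.0 (linear branch)
IO.inspect(SmoothL1Sketch.loss([[0.5, 3.0]], [[0.0, 0.0]]))
# (0.125 + 2.5) / 2 = 1.3125
```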

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

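Get the output size of the temporal JEPA model: the embed_dim of the [batch, embed_dim] representations produced by the context encoder and the predictor.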