Temporal JEPA — Joint Embedding Predictive Architecture for sequences.
Extends JEPA to temporal data (video frames, time series, trajectories). The context encoder processes visible timesteps with bidirectional attention, and the predictor estimates representations of masked timesteps.
Architecture
Visible timesteps [batch, seq_len, input_dim]
|
+========================+
| Context Encoder |
| Input projection |
| + Positional embed |
| Bidirectional Attn×N |
| LayerNorm |
| Mean Pool |
+========================+
|
[batch, embed_dim] (context representation)
|
+========================+
| Predictor |
| Project to pred_dim |
| MLP blocks × M |
| LayerNorm |
| Project to embed_dim |
+========================+
|
[batch, embed_dim] (predicted target representation)

The target encoder is architecturally identical to the context encoder with EMA-updated parameters (not part of the computational graph).
Returns
{context_encoder, predictor} — two Axon models.
Usage
{context_encoder, predictor} = TemporalJEPA.build(
input_dim: 128,
embed_dim: 128,
predictor_embed_dim: 64,
seq_len: 60,
mask_ratio: 0.5
)
# Training: encode visible frames, predict masked frame representations
# Target encoder uses EMA of context encoder weights
target_params = TemporalJEPA.ema_update(context_params, target_params, momentum: 0.996)

References
- Bardes et al., "Revisiting Feature Prediction for Learning Visual Representations from Video" (V-JEPA, Meta AI, 2024)
- Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" (CVPR 2023)
Summary
Functions
Build both the context encoder and predictor networks.
Build the context encoder for temporal sequences.
Build the predictor network.
Default EMA momentum.
Update target encoder parameters via exponential moving average.
Generate a temporal mask for masking timesteps.
Compute temporal JEPA loss (smooth L1 between predicted and target representations).
Get the output size of the temporal JEPA model.
Types
@type build_opt() ::
        {:input_dim, pos_integer()}
        | {:embed_dim, pos_integer()}
        | {:predictor_embed_dim, pos_integer()}
        | {:encoder_depth, pos_integer()}
        | {:predictor_depth, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:dropout, float()}
        | {:seq_len, pos_integer()}
        | {:mask_ratio, float()}
Options for build/1.
Functions
Build both the context encoder and predictor networks.
Options
:input_dim - Input feature dimension per timestep (required)
:embed_dim - Encoder embedding dimension (default: 128)
:predictor_embed_dim - Predictor internal dimension (default: 64)
:encoder_depth - Number of transformer blocks in encoder (default: 4)
:predictor_depth - Number of MLP blocks in predictor (default: 2)
:num_heads - Number of attention heads (default: 8)
:dropout - Dropout rate (default: 0.1)
:seq_len - Expected sequence length (default: 60)
:mask_ratio - Fraction of timesteps to mask (default: 0.5)
Returns
{context_encoder, predictor} tuple of Axon models.
Build the context encoder for temporal sequences.
Processes visible timesteps through bidirectional self-attention (no causal mask) and mean-pools to a fixed-size representation.
Returns
Axon model: [batch, seq_len, input_dim] → [batch, embed_dim]
Build the predictor network.
Takes the context encoder output (flat vector) and predicts the target encoder's representation of masked timesteps.
Returns
Axon model: [batch, embed_dim] → [batch, embed_dim]
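The two models returned by build/1 are ordinary Axon graphs and can be compiled and chained by hand. A minimal sketch of one forward pass through both — the batch size, Nx.template shapes, and the use of Axon.build/1 here are illustrative assumptions, not part of this module's API:

```elixir
# Sketch: chain the context encoder into the predictor for one forward pass.
{context_encoder, predictor} = TemporalJEPA.build(input_dim: 128, seq_len: 60)

{enc_init, enc_fwd} = Axon.build(context_encoder)
{pred_init, pred_fwd} = Axon.build(predictor)

# Dummy batch of 8 sequences, 60 timesteps, 128 features each.
visible = Nx.broadcast(0.0, {8, 60, 128})

enc_params = enc_init.(Nx.template({8, 60, 128}, :f32), %{})
context = enc_fwd.(enc_params, visible)        # [8, 128]

pred_params = pred_init.(Nx.template({8, 128}, :f32), %{})
predicted = pred_fwd.(pred_params, context)    # [8, 128]
```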
@spec default_momentum() :: float()
Default EMA momentum.
Update target encoder parameters via exponential moving average.
target = momentum * target + (1 - momentum) * context
Parameters
context_params - Current context encoder parameters
target_params - Current target encoder parameters
Options
:momentum- EMA coefficient (default: 0.996)
Returns
Updated target parameters.
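The update rule above can be sketched directly in Nx. This simplified version assumes the parameter maps are flat maps of tensors keyed by layer name; real Axon parameter maps are nested, so the actual ema_update/3 presumably recurses:

```elixir
# Simplified EMA sketch over flat parameter maps (assumption: real Axon params nest).
def ema_update_sketch(context_params, target_params, momentum \\ 0.996) do
  Map.new(target_params, fn {name, target} ->
    context = Map.fetch!(context_params, name)

    # target <- momentum * target + (1 - momentum) * context
    updated =
      target
      |> Nx.multiply(momentum)
      |> Nx.add(Nx.multiply(context, 1.0 - momentum))

    {name, updated}
  end)
end
```

With momentum close to 1.0 the target encoder changes slowly, which is what keeps the prediction targets stable during training.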
@spec generate_temporal_mask(Nx.Tensor.t(), pos_integer(), float()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Generate a temporal mask for masking timesteps.
Returns a boolean tensor of shape [seq_len] where true means visible
and false means masked.
Parameters
key - PRNG key
seq_len - Number of timesteps
mask_ratio - Fraction to mask (default: 0.5)
Returns
{visible_mask, key} where visible_mask is [seq_len] boolean tensor.
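One way to implement this with Nx.Random is the double-argsort rank trick, which masks exactly round(seq_len * mask_ratio) timesteps. A sketch only; the module's actual implementation may differ:

```elixir
# Sketch: rank timesteps by random noise and mask the lowest-ranked ones.
def generate_temporal_mask_sketch(key, seq_len, mask_ratio \\ 0.5) do
  {noise, key} = Nx.Random.uniform(key, shape: {seq_len})
  num_masked = round(seq_len * mask_ratio)

  # argsort of an argsort gives each element's rank within the tensor.
  ranks = noise |> Nx.argsort() |> Nx.argsort()

  # Rank >= num_masked stays visible (1); lower ranks are masked (0).
  visible = Nx.greater_equal(ranks, num_masked)
  {visible, key}
end
```

Note that Nx has no dedicated boolean type, so the "boolean tensor" is a u8 tensor of 0s and 1s.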
@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute temporal JEPA loss (smooth L1 between predicted and target representations).
Parameters
predicted - Predictor output: [batch, embed_dim]
target - Target encoder output: [batch, embed_dim]
Returns
Scalar loss tensor.
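Smooth L1 (Huber with beta = 1.0) is quadratic near zero and linear in the tails, making it less sensitive to outlier embedding dimensions than plain L2. A defn sketch of the formula (the module's actual implementation may differ):

```elixir
import Nx.Defn

# Sketch of smooth L1: quadratic for |diff| < 1, linear beyond, mean-reduced.
defn smooth_l1_sketch(predicted, target) do
  diff = predicted - target
  abs_diff = Nx.abs(diff)

  elementwise = Nx.select(abs_diff < 1.0, 0.5 * diff * diff, abs_diff - 0.5)
  Nx.mean(elementwise)
end
```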
@spec output_size(keyword()) :: pos_integer()
Get the output size of the temporal JEPA model.