JEPA - Joint Embedding Predictive Architecture.
Implements JEPA from "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" (Assran et al., CVPR 2023). JEPA predicts representations of masked regions rather than pixel values, learning more abstract features.
Key Innovations
- Predict representations, not pixels: Unlike MAE which reconstructs raw input, JEPA predicts the target encoder's representation of masked regions
- Asymmetric architecture: A narrow predictor bridges context to target space
- EMA target: The target encoder's weights are an exponential moving average of the context encoder's weights (the same pattern as BYOL, applied at training time)
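The EMA rule in the last bullet can be sketched on a single scalar weight (a toy illustration only; the real update applies this element-wise over every tensor in the parameter map):

```elixir
# Toy scalar version of the EMA target update used for the target encoder.
ema = fn context_w, target_w, momentum ->
  momentum * target_w + (1 - momentum) * context_w
end

# With the default momentum of 0.996, the target weight moves only
# 0.4% of the way toward the context weight per step, so the target
# evolves slowly and provides stable prediction targets.
ema.(1.0, 0.0, 0.996)
```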
Architecture
Input (with mask)
|
v
+===================+
| Context Encoder | (processes visible patches)
| Projection |
| + Pos Embed |
| Transformer x N |
| LayerNorm |
| Mean Pool |
+===================+
|
v
[batch, embed_dim] (context representation)
Context Repr + Mask Tokens
|
v
+===================+
| Predictor | (narrow transformer)
| Project to |
| predictor_dim |
| + Pos Embed |
| Concat mask tkns |
| Transformer x M |
| LayerNorm |
| Project back to |
| embed_dim |
+===================+
|
v
[batch, embed_dim] (predicted target representation)

The target encoder is architecturally identical to the context encoder, with EMA-updated parameters (it is not part of the computational graph).
Returns
{context_encoder, predictor} — two Axon models.
Usage
{context_encoder, predictor} = JEPA.build(
input_dim: 287,
embed_dim: 256,
predictor_embed_dim: 128,
encoder_depth: 6,
predictor_depth: 4
)
# After each training step, update target via EMA
target_params = JEPA.ema_update(context_params, target_params, momentum: 0.996)

References
- "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" (Assran et al., CVPR 2023)
- arXiv: https://arxiv.org/abs/2301.08243
Summary
Functions
Build both the context encoder and predictor networks.
Build the context encoder.
Build the predictor network.
Default EMA momentum.
Update target encoder parameters via exponential moving average.
Compute the JEPA loss (smooth L1 / Huber loss between predicted and target representations).
Get the output size of the JEPA model.
Get recommended defaults.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:encoder_depth, pos_integer()} | {:input_dim, pos_integer()} | {:mlp_ratio, float()} | {:num_heads, pos_integer()} | {:predictor_depth, pos_integer()} | {:predictor_embed_dim, pos_integer()}
Options for build/1.
Functions
Build both the context encoder and predictor networks.
Options
- :input_dim - Input feature dimension (required)
- :embed_dim - Encoder embedding dimension (default: 256)
- :predictor_embed_dim - Predictor hidden dimension, narrower than the encoder (default: 128)
- :encoder_depth - Number of transformer blocks in the encoder (default: 6)
- :predictor_depth - Number of transformer blocks in the predictor (default: 4)
- :num_heads - Number of attention heads (default: 8)
- :mlp_ratio - FFN expansion ratio (default: 4.0)
- :dropout - Dropout rate (default: 0.1)
Returns
{context_encoder, predictor} tuple of Axon models.
Build the context encoder.
Processes input features through a projection, positional embedding, and a stack of transformer blocks, then mean-pools to produce a fixed-size representation.
Options
- :input_dim - Input feature dimension (required)
- :embed_dim - Output embedding dimension (default: 256)
- :encoder_depth - Number of transformer blocks (default: 6)
- :num_heads - Attention heads (default: 8)
- :mlp_ratio - FFN expansion ratio (default: 4.0)
- :dropout - Dropout rate (default: 0.1)
Returns
Axon model mapping [batch, input_dim] to [batch, embed_dim].
Build the predictor network.
Takes context encoder output, projects to a narrower dimension, processes through transformer blocks, and projects back to embed_dim.
Options
- :embed_dim - Context encoder output dimension (default: 256)
- :predictor_embed_dim - Predictor internal dimension (default: 128)
- :predictor_depth - Number of transformer blocks (default: 4)
- :num_heads - Attention heads (default: 8)
- :mlp_ratio - FFN expansion ratio (default: 4.0)
- :dropout - Dropout rate (default: 0.1)
Returns
Axon model mapping [batch, embed_dim] to [batch, embed_dim].
@spec default_momentum() :: float()
Default EMA momentum.
Update target encoder parameters via exponential moving average.
target_params = momentum * target_params + (1 - momentum) * context_params
Parameters
- context_params - Current context encoder parameters (map of tensors)
- target_params - Current target encoder parameters (map of tensors)
Options
- :momentum - EMA momentum coefficient (default: 0.996)
Returns
Updated target parameters.
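What the update computes can be sketched over plain nested maps of floats (a simplification for illustration; `JEPA.ema_update/3` performs the same recursion over maps of Nx tensors):

```elixir
# Recursively apply the EMA rule to matching leaves of two parameter maps.
# Simplified to plain floats; the real function operates on Nx tensors.
defmodule EMASketch do
  def update(context, target, momentum) when is_map(context) do
    Map.new(target, fn {key, target_leaf} ->
      {key, update(Map.fetch!(context, key), target_leaf, momentum)}
    end)
  end

  def update(context, target, momentum) when is_number(context) do
    momentum * target + (1 - momentum) * context
  end
end

context = %{"layer_0" => %{"kernel" => 1.0, "bias" => 0.5}}
target = %{"layer_0" => %{"kernel" => 0.0, "bias" => 0.0}}
EMASketch.update(context, target, 0.996)
```

Note that the target parameters never receive gradients; they change only through this update, which is why the target encoder sits outside the computational graph.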
@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the JEPA loss (smooth L1 / Huber loss between predicted and target representations).
Parameters
- predicted - Predictor output: [batch, embed_dim]
- target - Target encoder output: [batch, embed_dim]
Returns
Scalar loss tensor.
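The smooth L1 (Huber) form can be sketched on flat lists of floats (a scalar illustration assuming the standard delta = 1 transition point; the real `loss/2` computes the same quantity on Nx tensors):

```elixir
# Smooth L1 / Huber loss (delta = 1) between two flat lists, averaged.
# Quadratic for small errors (|d| < 1), linear for large ones, which
# makes it less sensitive to outliers than plain L2.
defmodule SmoothL1Sketch do
  def loss(predicted, target) do
    diffs = Enum.zip_with(predicted, target, fn p, t -> abs(p - t) end)

    terms =
      Enum.map(diffs, fn d ->
        if d < 1.0, do: 0.5 * d * d, else: d - 0.5
      end)

    Enum.sum(terms) / length(terms)
  end
end

# One small error (quadratic branch) and one large error (linear branch).
SmoothL1Sketch.loss([0.5, 2.0], [0.0, 0.0])
```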
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the JEPA model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.