Edifice.Contrastive.JEPA (Edifice v0.2.0)


JEPA - Joint Embedding Predictive Architecture.

Implements JEPA from "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture" (Assran et al., CVPR 2023). JEPA predicts representations of masked regions rather than pixel values, learning more abstract features.

Key Innovations

  • Predict representations, not pixels: Unlike MAE, which reconstructs the raw input, JEPA predicts the target encoder's representation of masked regions
  • Asymmetric architecture: A narrow predictor bridges context representations into the target space
  • EMA target: The target encoder tracks an exponential moving average of the context encoder's weights (the same pattern as BYOL, handled at training time)

Architecture

Input (with mask)
      |
      v
+===================+
| Context Encoder   |  (processes visible patches)
|  Projection       |
|  + Pos Embed      |
|  Transformer x N  |
|  LayerNorm        |
|  Mean Pool        |
+===================+
      |
      v
[batch, embed_dim]   (context representation)

Context Repr + Mask Tokens
      |
      v
+===================+
| Predictor         |  (narrow transformer)
|  Project to       |
|    predictor_dim  |
|  + Pos Embed      |
|  Concat mask tkns |
|  Transformer x M  |
|  LayerNorm        |
|  Project back to  |
|    embed_dim      |
+===================+
      |
      v
[batch, embed_dim]   (predicted target representation)

The target encoder is architecturally identical to the context encoder with EMA-updated parameters (not part of the computational graph).
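
Since the two encoders share an architecture, the target can be initialized as a copy of the context weights at training setup. A minimal sketch, assuming Axon 0.6-style build/init (the 287-dim input matches the usage example below):

alias Edifice.Contrastive.JEPA

context_encoder = JEPA.build_context_encoder(input_dim: 287)
{init_fn, _predict_fn} = Axon.build(context_encoder)
context_params = init_fn.(Nx.template({1, 287}, :f32), %{})

# The target starts as an exact copy and is never touched by the
# optimizer; it changes only via JEPA.ema_update/3 after each step.
target_params = context_params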

Returns

{context_encoder, predictor} — two Axon models.

Usage

alias Edifice.Contrastive.JEPA

{context_encoder, predictor} = JEPA.build(
  input_dim: 287,
  embed_dim: 256,
  predictor_embed_dim: 128,
  encoder_depth: 6,
  predictor_depth: 4
)

# After each training step, update target via EMA
target_params = JEPA.ema_update(context_params, target_params, momentum: 0.996)
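
A sketch of how one training step might wire these together. Illustrative only: masked_input, full_input, and the parameter maps (context_params, predictor_params, target_params) are assumed to come from your training loop, and inside a defn-compiled step the target pass would be wrapped in stop_grad/1 so no gradients flow through it:

{_init, encode}  = Axon.build(context_encoder)
{_init, predict} = Axon.build(predictor)

# Context encoder sees only the visible patches.
context_repr = encode.(context_params, masked_input)
predicted    = predict.(predictor_params, context_repr)

# Target pass uses the EMA weights and stays out of the gradient graph.
target = encode.(target_params, full_input)

loss = JEPA.loss(predicted, target)
target_params = JEPA.ema_update(context_params, target_params, momentum: 0.996)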

References

  • Assran et al., "Self-Supervised Learning from Images with a Joint-Embedding Predictive Architecture", CVPR 2023 (arXiv:2301.08243)

Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build both the context encoder and predictor networks.
  • build_context_encoder(opts \\ []) - Build the context encoder.
  • build_predictor(opts \\ []) - Build the predictor network.
  • default_momentum() - Default EMA momentum.
  • ema_update(context_params, target_params, opts \\ []) - Update target encoder parameters via exponential moving average.
  • loss(predicted, target) - Compute the JEPA loss (smooth L1 / Huber loss between predicted and target representations).
  • output_size(opts \\ []) - Get the output size of the JEPA model.
  • Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:encoder_depth, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_heads, pos_integer()}
  | {:predictor_depth, pos_integer()}
  | {:predictor_embed_dim, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build both the context encoder and predictor networks.

Options

  • :input_dim - Input feature dimension (required)
  • :embed_dim - Encoder embedding dimension (default: 256)
  • :predictor_embed_dim - Predictor hidden dimension, narrower than encoder (default: 128)
  • :encoder_depth - Number of transformer blocks in encoder (default: 6)
  • :predictor_depth - Number of transformer blocks in predictor (default: 4)
  • :num_heads - Number of attention heads (default: 8)
  • :mlp_ratio - FFN expansion ratio (default: 4.0)
  • :dropout - Dropout rate (default: 0.1)

Returns

{context_encoder, predictor} tuple of Axon models.

build_context_encoder(opts \\ [])

@spec build_context_encoder(keyword()) :: Axon.t()

Build the context encoder.

Processes input features through a projection, positional embedding, and a stack of transformer blocks, then mean-pools to produce a fixed-size representation.

Options

  • :input_dim - Input feature dimension (required)
  • :embed_dim - Output embedding dimension (default: 256)
  • :encoder_depth - Number of transformer blocks (default: 6)
  • :num_heads - Attention heads (default: 8)
  • :mlp_ratio - FFN expansion ratio (default: 4.0)
  • :dropout - Dropout rate (default: 0.1)

Returns

Axon model mapping [batch, input_dim] to [batch, embed_dim].
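
A quick shape check, as a sketch (assuming Axon 0.6-style build; the batch size 4 and zero-valued input are arbitrary):

encoder = JEPA.build_context_encoder(input_dim: 287, embed_dim: 256)
{init_fn, predict_fn} = Axon.build(encoder)
params = init_fn.(Nx.template({4, 287}, :f32), %{})

predict_fn.(params, Nx.broadcast(0.0, {4, 287})) |> Nx.shape()
#=> {4, 256}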

build_predictor(opts \\ [])

@spec build_predictor(keyword()) :: Axon.t()

Build the predictor network.

Takes context encoder output, projects to a narrower dimension, processes through transformer blocks, and projects back to embed_dim.

Options

  • :embed_dim - Context encoder output dimension (default: 256)
  • :predictor_embed_dim - Predictor internal dimension (default: 128)
  • :predictor_depth - Number of transformer blocks (default: 4)
  • :num_heads - Attention heads (default: 8)
  • :mlp_ratio - FFN expansion ratio (default: 4.0)
  • :dropout - Dropout rate (default: 0.1)

Returns

Axon model mapping [batch, embed_dim] to [batch, embed_dim].
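
Because input and output are both embed_dim-sized, the predictor composes directly with the encoder output. A sketch under the same assumptions as the encoder example above:

predictor = JEPA.build_predictor(embed_dim: 256, predictor_embed_dim: 128)
{init_fn, predict_fn} = Axon.build(predictor)
params = init_fn.(Nx.template({4, 256}, :f32), %{})

# Encoder output feeds straight in; width is preserved end to end.
predict_fn.(params, Nx.broadcast(0.0, {4, 256})) |> Nx.shape()
#=> {4, 256}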

default_momentum()

@spec default_momentum() :: float()

Default EMA momentum (0.996), matching the :momentum default of ema_update/3.

ema_update(context_params, target_params, opts \\ [])

@spec ema_update(map(), map(), keyword()) :: map()

Update target encoder parameters via exponential moving average.

target_params = momentum * target_params + (1 - momentum) * context_params

Parameters

  • context_params - Current context encoder parameters (map of tensors)
  • target_params - Current target encoder parameters (map of tensors)

Options

  • :momentum - EMA momentum coefficient (default: 0.996)

Returns

Updated target parameters.
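
For intuition, a minimal Nx sketch of the update rule above. The nested layer-name => param-name map layout is an assumption here; Edifice's actual traversal of the parameter map may differ:

defmodule EmaSketch do
  # target <- momentum * target + (1 - momentum) * context, leaf by leaf.
  def update(context_params, target_params, momentum \\ 0.996) do
    Map.new(target_params, fn {layer, target_layer} ->
      context_layer = Map.fetch!(context_params, layer)

      updated =
        Map.new(target_layer, fn {name, t} ->
          c = Map.fetch!(context_layer, name)
          {name, Nx.add(Nx.multiply(t, momentum), Nx.multiply(c, 1.0 - momentum))}
        end)

      {layer, updated}
    end)
  end
end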

loss(predicted, target)

@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the JEPA loss (smooth L1 / Huber loss between predicted and target representations).

Parameters

  • predicted - Predictor output: [batch, embed_dim]
  • target - Target encoder output: [batch, embed_dim]

Returns

Scalar loss tensor.
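
For intuition, a minimal Nx.Defn sketch of a smooth L1 loss with the quadratic/linear transition at 1.0 (the exact delta Edifice uses is an assumption):

defmodule LossSketch do
  import Nx.Defn

  defn smooth_l1(predicted, target) do
    diff = Nx.abs(predicted - target)
    # Quadratic for |diff| < 1, linear beyond: robust to outliers while
    # keeping gradients smooth near zero.
    Nx.mean(Nx.select(diff < 1.0, 0.5 * diff * diff, diff - 0.5))
  end
end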

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the JEPA model.
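
Since both the context encoder and the predictor emit [batch, embed_dim], this presumably reports the :embed_dim option; a hedged usage sketch (the exact option handling is an assumption):

JEPA.output_size(embed_dim: 512)  #=> 512
JEPA.output_size()                #=> 256, the :embed_dim default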