# `Edifice.Contrastive.JEPA`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/contrastive/jepa.ex#L1)

JEPA - Joint-Embedding Predictive Architecture.

Implements JEPA from "Self-Supervised Learning from Images with a Joint-Embedding
Predictive Architecture" (Assran et al., CVPR 2023). JEPA predicts representations
of masked regions rather than pixel values, learning more abstract features.

## Key Innovations

- **Predict representations, not pixels**: Unlike MAE, which reconstructs the raw
  input, JEPA predicts the target encoder's representation of masked regions
- **Asymmetric architecture**: A narrow predictor bridges context to target space
- **EMA target**: Target encoder uses exponential moving average of context encoder
  weights (same pattern as BYOL, handled at training time)

## Architecture

```
Input (with mask)
      |
      v
+===================+
| Context Encoder   |  (processes visible patches)
|  Projection       |
|  + Pos Embed      |
|  Transformer x N  |
|  LayerNorm        |
|  Mean Pool        |
+===================+
      |
      v
[batch, embed_dim]   (context representation)

Context Repr + Mask Tokens
      |
      v
+===================+
| Predictor         |  (narrow transformer)
|  Project to       |
|    predictor_dim  |
|  + Pos Embed      |
|  Concat mask tkns |
|  Transformer x M  |
|  LayerNorm        |
|  Project back to  |
|    embed_dim      |
+===================+
      |
      v
[batch, embed_dim]   (predicted target representation)
```

The target encoder is architecturally identical to the context encoder
with EMA-updated parameters (not part of the computational graph).
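
Since the two encoders share one graph, the same builder can serve for the
target side. A minimal sketch, assuming the option values from the usage
example below:

```elixir
# Hedged sketch: build the target encoder with the context encoder's
# builder; only its parameters differ (they are an EMA of the context
# params, updated outside the gradient computation).
target_encoder =
  JEPA.build_context_encoder(input_dim: 287, embed_dim: 256, encoder_depth: 6)
```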

## Returns

`{context_encoder, predictor}` — two Axon models.

## Usage

    {context_encoder, predictor} = JEPA.build(
      input_dim: 287,
      embed_dim: 256,
      predictor_embed_dim: 128,
      encoder_depth: 6,
      predictor_depth: 4
    )

    # After each training step, update target via EMA
    target_params = JEPA.ema_update(context_params, target_params, momentum: 0.996)
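
A fuller sketch of one training step, with masking and the optimizer assumed
to live in the surrounding loop. `visible`, `full`, and the parameter maps are
hypothetical names; `target_encoder` is built as in the sketch above, and the
target pass must stay outside the function you differentiate:

```elixir
# Context pathway: gradients flow through these two models.
context = Axon.predict(context_encoder, context_params, visible)
predicted = Axon.predict(predictor, predictor_params, context)

# Target pathway: EMA weights, excluded from the computational graph.
target = Axon.predict(target_encoder, target_params, full)

loss = JEPA.loss(predicted, target)

# After the optimizer step on the context/predictor parameters:
target_params = JEPA.ema_update(context_params, target_params, momentum: 0.996)
```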

## References

- "Self-Supervised Learning from Images with a Joint-Embedding Predictive
  Architecture" (Assran et al., CVPR 2023)
- arXiv: https://arxiv.org/abs/2301.08243

# `build_opt`

```elixir
@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:encoder_depth, pos_integer()}
  | {:input_dim, pos_integer()}
  | {:mlp_ratio, float()}
  | {:num_heads, pos_integer()}
  | {:predictor_depth, pos_integer()}
  | {:predictor_embed_dim, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t()}
```

Build both the context encoder and predictor networks.

## Options
  - `:input_dim` - Input feature dimension (required)
  - `:embed_dim` - Encoder embedding dimension (default: 256)
  - `:predictor_embed_dim` - Predictor hidden dimension, narrower than encoder (default: 128)
  - `:encoder_depth` - Number of transformer blocks in encoder (default: 6)
  - `:predictor_depth` - Number of transformer blocks in predictor (default: 4)
  - `:num_heads` - Number of attention heads (default: 8)
  - `:mlp_ratio` - FFN expansion ratio (default: 4.0)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  `{context_encoder, predictor}` tuple of Axon models.
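
A hedged initialization sketch using Axon's standard `build/2` contract. The
template shapes follow the documented `[batch, input_dim]` and
`[batch, embed_dim]` contracts; the empty-map initial state may need to be
`Axon.ModelState.empty()` on newer Axon versions:

```elixir
{context_encoder, predictor} = JEPA.build(input_dim: 287)

# Initialize each model from a shape template.
{enc_init, _enc_predict} = Axon.build(context_encoder)
context_params = enc_init.(Nx.template({1, 287}, :f32), %{})

{pred_init, _pred_predict} = Axon.build(predictor)
predictor_params = pred_init.(Nx.template({1, 256}, :f32), %{})
```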

# `build_context_encoder`

```elixir
@spec build_context_encoder(keyword()) :: Axon.t()
```

Build the context encoder.

Processes input features through a projection, positional embedding,
and a stack of transformer blocks, then mean-pools to produce a
fixed-size representation.

## Options
  - `:input_dim` - Input feature dimension (required)
  - `:embed_dim` - Output embedding dimension (default: 256)
  - `:encoder_depth` - Number of transformer blocks (default: 6)
  - `:num_heads` - Attention heads (default: 8)
  - `:mlp_ratio` - FFN expansion ratio (default: 4.0)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  Axon model mapping `[batch, input_dim]` to `[batch, embed_dim]`.
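
A quick shape check, as a sketch (dummy zero input; `%{}` as the initial
state, with the same Axon-version caveat as above):

```elixir
encoder = JEPA.build_context_encoder(input_dim: 287, embed_dim: 256)

{init_fn, predict_fn} = Axon.build(encoder)
params = init_fn.(Nx.template({8, 287}, :f32), %{})

out = predict_fn.(params, Nx.broadcast(0.0, {8, 287}))
# Nx.shape(out) => {8, 256}
```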

# `build_predictor`

```elixir
@spec build_predictor(keyword()) :: Axon.t()
```

Build the predictor network.

Takes the context encoder's output, projects it to a narrower dimension,
processes it through transformer blocks, and projects back to `embed_dim`.

## Options
  - `:embed_dim` - Context encoder output dimension (default: 256)
  - `:predictor_embed_dim` - Predictor internal dimension (default: 128)
  - `:predictor_depth` - Number of transformer blocks (default: 4)
  - `:num_heads` - Attention heads (default: 8)
  - `:mlp_ratio` - FFN expansion ratio (default: 4.0)
  - `:dropout` - Dropout rate (default: 0.1)

## Returns
  Axon model mapping `[batch, embed_dim]` to `[batch, embed_dim]`.
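
In use, the predictor sits directly downstream of the context encoder. A
sketch, where `context_repr` stands for the encoder's `[batch, 256]` output
and `predictor_params` come from initialization as shown earlier:

```elixir
predictor = JEPA.build_predictor(embed_dim: 256, predictor_embed_dim: 128)

# Same embedding space in and out, ready to compare against the target.
predicted = Axon.predict(predictor, predictor_params, context_repr)
# Nx.shape(predicted) => {batch, 256}
```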

# `default_momentum`

```elixir
@spec default_momentum() :: float()
```

Default EMA momentum.
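
Presumably this mirrors the documented `:momentum` default of `ema_update/3`
below (an assumption, not verified against the source):

```elixir
target_params =
  JEPA.ema_update(context_params, target_params, momentum: JEPA.default_momentum())
```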

# `ema_update`

```elixir
@spec ema_update(map(), map(), keyword()) :: map()
```

Update target encoder parameters via exponential moving average.

    target_params = momentum * target_params + (1 - momentum) * context_params

## Parameters
  - `context_params` - Current context encoder parameters (map of tensors)
  - `target_params` - Current target encoder parameters (map of tensors)

## Options
  - `:momentum` - EMA momentum coefficient (default: 0.996)

## Returns
  Updated target parameters.
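
For reference, a hand-rolled equivalent of the update rule, assuming the
parameters are nested maps of `Nx.Tensor`s; the library's actual traversal may
differ (for instance, if params are wrapped in `Axon.ModelState`):

```elixir
defmodule EMASketch do
  @moduledoc false

  def ema_update(context_params, target_params, opts \\ []) do
    momentum = Keyword.get(opts, :momentum, 0.996)
    blend(target_params, context_params, momentum)
  end

  # Leaf: blend each target tensor toward its matching context tensor.
  defp blend(%Nx.Tensor{} = t, %Nx.Tensor{} = c, m) do
    Nx.add(Nx.multiply(t, m), Nx.multiply(c, 1.0 - m))
  end

  # Branch: walk both nested parameter maps in lockstep.
  defp blend(%{} = t, %{} = c, m) do
    Map.new(t, fn {k, tv} -> {k, blend(tv, Map.fetch!(c, k), m)} end)
  end
end
```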

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Compute the JEPA loss (smooth L1 / Huber loss between predicted and target representations).

## Parameters
  - `predicted` - Predictor output: `[batch, embed_dim]`
  - `target` - Target encoder output: `[batch, embed_dim]`

## Returns
  Scalar loss tensor.
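
Written out in Nx for reference, as a sketch; the `delta` threshold of `1.0`
is an assumption, since the module's actual value isn't documented here:

```elixir
defmodule LossSketch do
  @moduledoc false

  # Smooth L1 / Huber: quadratic near zero, linear in the tails.
  def smooth_l1(predicted, target, delta \\ 1.0) do
    diff = Nx.abs(Nx.subtract(predicted, target))

    quadratic = diff |> Nx.multiply(diff) |> Nx.multiply(0.5)
    linear = Nx.multiply(delta, Nx.subtract(diff, 0.5 * delta))

    Nx.select(Nx.less(diff, delta), quadratic, linear) |> Nx.mean()
  end
end
```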

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of the JEPA model.
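
A hedged usage note: since both models emit `[batch, embed_dim]`, the output
size presumably tracks `:embed_dim`:

```elixir
size = JEPA.output_size(embed_dim: 256)
# Presumably 256 (assumption: output size follows :embed_dim).
```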

# `recommended_defaults`

```elixir
@spec recommended_defaults() :: keyword()
```

Get recommended defaults.

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
