# `Edifice.Contrastive.VICReg`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/contrastive/vicreg.ex#L1)

VICReg - Variance-Invariance-Covariance Regularization.

Implements VICReg from "VICReg: Variance-Invariance-Covariance Regularization
for Self-Supervised Learning" (Bardes et al., ICLR 2022). VICReg prevents
representation collapse through three explicit regularization terms applied
directly to the embedding vectors, without requiring negative pairs,
asymmetric networks, or momentum encoders.

## Key Innovations

- **Explicit collapse prevention**: Three distinct terms each prevent a
  different mode of collapse
- **No architectural tricks**: Symmetric architecture, no stop-gradient,
  no momentum encoder, no negative mining
- **Interpretable loss**: Each term has a clear geometric meaning

## Loss Terms

1. **Variance** (v): Maintain variance of each embedding dimension above
   a threshold (prevents informational collapse where all embeddings
   become identical)
2. **Invariance** (i): MSE between embeddings of augmented views
   (ensures representations are view-invariant)
3. **Covariance** (c): Decorrelate embedding dimensions
   (prevents dimensional collapse where all dimensions are correlated)

```
L = lambda * invariance(Z, Z')
  + mu * [variance(Z) + variance(Z')]
  + nu * [covariance(Z) + covariance(Z')]
```

## Architecture

```
Augmented View 1         Augmented View 2
      |                         |
      v                         v
+------------+           +------------+
|  Encoder   |           |  Encoder   |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
+------------+           +------------+
| Projector  |           | Projector  |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
     Z                         Z'
      |                         |
      +------> VICReg Loss <----+
```

## Usage

    model = VICReg.build(encoder_dim: 287, projection_dim: 256)

    # Compute loss between two batches of projections
    loss = VICReg.vicreg_loss(z, z_prime,
      lambda_inv: 25.0,
      mu_var: 25.0,
      nu_cov: 1.0
    )

## References
- Paper: https://arxiv.org/abs/2105.04906

# `build_opt`

```elixir
@type build_opt() ::
  {:encoder_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:projection_dim, pos_integer()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a VICReg model (encoder + projector).

## Options
  - `:encoder_dim` - Input feature dimension (required)
  - `:projection_dim` - Projector output dimension (default: 256)
  - `:hidden_size` - Hidden dimension for encoder and projector (default: 512)

## Returns
  An Axon model mapping inputs to projection embeddings.

# `covariance_loss`

```elixir
@spec covariance_loss(Nx.Tensor.t()) :: Nx.Tensor.t()
```

Covariance term: decorrelate embedding dimensions.

Prevents dimensional collapse by pushing the off-diagonal elements
of the covariance matrix toward zero.
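As a rough standalone sketch of this term with Nx (assuming a `[batch, dim]` tensor and the unbiased `n - 1` covariance estimate used in the paper; the `covariance_loss` binding here is hypothetical and the module's actual implementation may differ):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hedged sketch, not the library's actual code: square the off-diagonal
# entries of the embedding covariance matrix, sum, and scale by 1/d.
covariance_loss = fn z ->
  {n, d} = Nx.shape(z)
  centered = Nx.subtract(z, Nx.mean(z, axes: [0]))
  cov = Nx.divide(Nx.dot(Nx.transpose(centered), centered), n - 1)
  off_diag = Nx.subtract(cov, Nx.multiply(Nx.eye(d), cov))
  off_diag |> Nx.pow(2) |> Nx.sum() |> Nx.divide(d)
end
```

Zeroing only the off-diagonal (rather than the whole matrix) is what distinguishes this term from the variance term: per-dimension spread is left to the variance hinge, while cross-dimension correlation is penalized here.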

# `default_hidden_size`

```elixir
@spec default_hidden_size() :: pos_integer()
```

Default encoder/projector hidden dimension.

# `default_lambda_inv`

```elixir
@spec default_lambda_inv() :: float()
```

Default invariance loss coefficient.

# `default_mu_var`

```elixir
@spec default_mu_var() :: float()
```

Default variance loss coefficient.

# `default_nu_cov`

```elixir
@spec default_nu_cov() :: float()
```

Default covariance loss coefficient.

# `default_projection_dim`

```elixir
@spec default_projection_dim() :: pos_integer()
```

Default projection head output dimension.

# `default_variance_target`

```elixir
@spec default_variance_target() :: float()
```

Default variance target threshold.

# `invariance_loss`

```elixir
@spec invariance_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
```

Invariance term: MSE between the two views.

Encourages representations of augmented views to be similar.
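This term is plain element-wise MSE; a minimal Nx sketch (the `invariance_loss` binding is hypothetical, not the module's code):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hedged sketch: mean squared difference between the two views' embeddings.
invariance_loss = fn z, z_prime ->
  Nx.mean(Nx.pow(Nx.subtract(z, z_prime), 2))
end
```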

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of the VICReg model.

# `variance_loss`

```elixir
@spec variance_loss(Nx.Tensor.t(), float()) :: Nx.Tensor.t()
```

Variance term: hinge loss on standard deviation.

Prevents informational collapse by ensuring each dimension maintains
variance above a target threshold across the batch.
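A minimal Nx sketch of the hinge (assuming a `[batch, dim]` tensor and a small epsilon inside the square root for numerical stability, as in the paper; the `variance_loss` binding is hypothetical and the module's implementation may differ):

```elixir
Mix.install([{:nx, "~> 0.7"}])

# Hedged sketch: penalize each dimension whose batch standard deviation
# falls below `target`; dimensions already above target contribute zero.
variance_loss = fn z, target ->
  std = z |> Nx.variance(axes: [0]) |> Nx.add(1.0e-4) |> Nx.sqrt()
  Nx.mean(Nx.max(Nx.subtract(target, std), 0.0))
end
```

Note the hinge is one-sided: variance above the target is never penalized, so this term only pushes collapsed dimensions back apart.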

# `vicreg_loss`

```elixir
@spec vicreg_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
```

Compute the full VICReg loss.

## Parameters
  - `z` - Embeddings from view 1: [batch, projection_dim]
  - `z_prime` - Embeddings from view 2: [batch, projection_dim]

## Options
  - `:lambda_inv` - Invariance loss weight (default: 25.0)
  - `:mu_var` - Variance loss weight (default: 25.0)
  - `:nu_cov` - Covariance loss weight (default: 1.0)
  - `:variance_target` - Target standard deviation (default: 1.0)

## Returns
  Scalar loss tensor.
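Putting the three terms together, the weighted sum from the overview can be sketched end to end in Nx (a hypothetical standalone module using the documented default weights; the library's `vicreg_loss/3` may differ in detail):

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule VICRegLossSketch do
  # Hedged sketch of the full VICReg loss, not the library's actual code:
  # L = lambda * inv(Z, Z') + mu * [var(Z) + var(Z')] + nu * [cov(Z) + cov(Z')]
  def loss(z, z_prime, opts \\ []) do
    lambda = Keyword.get(opts, :lambda_inv, 25.0)
    mu = Keyword.get(opts, :mu_var, 25.0)
    nu = Keyword.get(opts, :nu_cov, 1.0)
    target = Keyword.get(opts, :variance_target, 1.0)

    Nx.multiply(lambda, inv(z, z_prime))
    |> Nx.add(Nx.multiply(mu, Nx.add(var(z, target), var(z_prime, target))))
    |> Nx.add(Nx.multiply(nu, Nx.add(cov(z), cov(z_prime))))
  end

  # Invariance: MSE between views.
  defp inv(z, zp), do: Nx.mean(Nx.pow(Nx.subtract(z, zp), 2))

  # Variance: hinge keeping each dimension's std above `target`.
  defp var(z, target) do
    std = z |> Nx.variance(axes: [0]) |> Nx.add(1.0e-4) |> Nx.sqrt()
    Nx.mean(Nx.max(Nx.subtract(target, std), 0.0))
  end

  # Covariance: mean squared off-diagonal of the covariance matrix.
  defp cov(z) do
    {n, d} = Nx.shape(z)
    centered = Nx.subtract(z, Nx.mean(z, axes: [0]))
    m = Nx.divide(Nx.dot(Nx.transpose(centered), centered), n - 1)
    off = Nx.subtract(m, Nx.multiply(Nx.eye(d), m))
    off |> Nx.pow(2) |> Nx.sum() |> Nx.divide(d)
  end
end
```

Note that the variance and covariance terms are computed on each view separately and summed, while the invariance term couples the two views; only the invariance term carries a gradient between `z` and `z_prime`.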

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
