Edifice.Contrastive.VICReg (Edifice v0.2.0)

VICReg - Variance-Invariance-Covariance Regularization.

Implements VICReg from "VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning" (Bardes et al., ICLR 2022). VICReg prevents representation collapse through three explicit regularization terms applied directly to the embedding vectors, without requiring negative pairs, asymmetric networks, or momentum encoders.

Key Innovations

  • Explicit collapse prevention: Three distinct terms each prevent a different mode of collapse
  • No architectural tricks: Symmetric architecture, no stop-gradient, no momentum encoder, no negative mining
  • Interpretable loss: Each term has a clear geometric meaning

Loss Terms

  1. Variance (v): Keep the standard deviation of each embedding dimension above a threshold (prevents informational collapse where all embeddings become identical)
  2. Invariance (i): MSE between embeddings of augmented views (ensures representations are view-invariant)
  3. Covariance (c): Decorrelate embedding dimensions (prevents dimensional collapse where all dimensions are correlated)

The total loss is a weighted sum of the three terms:

L = lambda * invariance(Z, Z')
  + mu * [variance(Z) + variance(Z')]
  + nu * [covariance(Z) + covariance(Z')]
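
The terms in the formula above can be written out following the definitions in Bardes et al. Here n is the batch size, d the embedding dimension, z_i the i-th row of Z, z^j the j-th column, gamma the variance target, and epsilon a small constant for numerical stability. This restates the paper's notation; the library's implementation may differ in details such as normalization or the value of epsilon.

```latex
\begin{aligned}
s(Z, Z') &= \frac{1}{n} \sum_{i=1}^{n} \lVert z_i - z'_i \rVert_2^2 \\
v(Z) &= \frac{1}{d} \sum_{j=1}^{d} \max\!\left(0,\; \gamma - \sqrt{\operatorname{Var}(z^{j}) + \epsilon}\right) \\
C(Z) &= \frac{1}{n-1} \sum_{i=1}^{n} (z_i - \bar{z})(z_i - \bar{z})^{\top} \\
c(Z) &= \frac{1}{d} \sum_{i \neq j} [C(Z)]_{i,j}^{2} \\
L &= \lambda\, s(Z, Z') + \mu\,[v(Z) + v(Z')] + \nu\,[c(Z) + c(Z')]
\end{aligned}
```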

Architecture

Augmented View 1         Augmented View 2
      |                         |
      v                         v
+------------+           +------------+
|  Encoder   |           |  Encoder   |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
+------------+           +------------+
| Projector  |           | Projector  |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
     Z                         Z'
      |                         |
      +------> VICReg Loss <----+

Usage

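alias Edifice.Contrastive.VICReg

# Build the encoder + projector as a single Axon model
model = VICReg.build(encoder_dim: 287, projection_dim: 256)

# Compute loss between two batches of projections.
# z and z_prime are the projected embeddings of the two augmented views,
# each of shape [batch, projection_dim].
loss = VICReg.vicreg_loss(z, z_prime,
  lambda_inv: 25.0,
  mu_var: 25.0,
  nu_cov: 1.0
)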

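References

  • Bardes et al., "VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning" (ICLR 2022)
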
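Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a VICReg model (encoder + projector).

covariance_loss(z)
Covariance term: decorrelate embedding dimensions.

default_hidden_size()
Default encoder/projector hidden dimension

default_lambda_inv()
Default invariance loss coefficient

default_mu_var()
Default variance loss coefficient

default_nu_cov()
Default covariance loss coefficient

default_projection_dim()
Default projection head output dimension

default_variance_target()
Default variance target threshold

invariance_loss(z, z_prime)
Invariance term: MSE between the two views.

output_size(opts \\ [])
Get the output size of the VICReg model.

variance_loss(z, target \\ 1.0)
Variance term: hinge loss on standard deviation.

vicreg_loss(z, z_prime, opts \\ [])
Compute the full VICReg loss.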
Types

build_opt()

@type build_opt() ::
  {:encoder_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:projection_dim, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a VICReg model (encoder + projector).

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projector output dimension (default: 256)
  • :hidden_size - Hidden dimension for encoder and projector (default: 512)

Returns

An Axon model mapping inputs to projection embeddings.
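
A minimal sketch of running the returned model with Axon. It assumes a Mix project that already depends on Edifice, Axon (0.7 or later, for Axon.ModelState) and Nx, and that the model takes a single input of shape [batch, encoder_dim]; neither the input name nor its rank is stated here, so treat both as assumptions.

```elixir
# Assumes a Mix project that already depends on Edifice, Axon and Nx.
alias Edifice.Contrastive.VICReg

model = VICReg.build(encoder_dim: 287, projection_dim: 256)

# Compile the model into an init function and a predict function.
{init_fn, predict_fn} = Axon.build(model)

# Assumption: a single input of shape [batch, encoder_dim].
template = Nx.template({1, 287}, :f32)
params = init_fn.(template, Axon.ModelState.empty())

# A dummy batch of 8 feature vectors.
batch = Nx.broadcast(0.5, {8, 287})
z = predict_fn.(params, batch)

# If the assumptions hold, the trailing axis equals :projection_dim.
IO.inspect(Nx.shape(z))
```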

covariance_loss(z)

@spec covariance_loss(Nx.Tensor.t()) :: Nx.Tensor.t()

Covariance term: decorrelate embedding dimensions.

Prevents dimensional collapse by pushing the off-diagonal elements of the covariance matrix toward zero.
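
The covariance term can be sketched in plain Nx as follows. CovarianceSketch is a hypothetical module written for illustration, not the library's implementation; it assumes the unbiased (n - 1) covariance estimate and division by the embedding dimension, as in the paper.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule CovarianceSketch do
  # Sum of squared off-diagonal covariance entries, divided by the dimension d.
  def covariance(z) do
    {n, d} = Nx.shape(z)

    # Center each embedding dimension across the batch.
    centered = Nx.subtract(z, Nx.mean(z, axes: [0], keep_axes: true))

    # d x d covariance matrix (unbiased estimate, an assumption).
    cov = centered |> Nx.transpose() |> Nx.dot(centered) |> Nx.divide(n - 1)

    # Zero out the diagonal so only cross-dimension terms are penalized.
    off_diag = Nx.multiply(cov, Nx.subtract(1, Nx.eye(d)))

    off_diag |> Nx.pow(2) |> Nx.sum() |> Nx.divide(d)
  end
end

# Two perfectly correlated dimensions are penalized.
correlated = Nx.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
correlated_loss = correlated |> CovarianceSketch.covariance() |> Nx.to_number()
IO.inspect(correlated_loss)
# => 1.0

# Two uncorrelated dimensions incur no penalty.
decorrelated = Nx.tensor([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
decorrelated_loss = decorrelated |> CovarianceSketch.covariance() |> Nx.to_number()
IO.inspect(decorrelated_loss)
# => 0.0
```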

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default encoder/projector hidden dimension

default_lambda_inv()

@spec default_lambda_inv() :: float()

Default invariance loss coefficient

default_mu_var()

@spec default_mu_var() :: float()

Default variance loss coefficient

default_nu_cov()

@spec default_nu_cov() :: float()

Default covariance loss coefficient

default_projection_dim()

@spec default_projection_dim() :: pos_integer()

Default projection head output dimension

default_variance_target()

@spec default_variance_target() :: float()

Default variance target threshold

invariance_loss(z, z_prime)

@spec invariance_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Invariance term: MSE between the two views.

Encourages representations of augmented views to be similar.
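
A minimal Nx sketch of the invariance term. InvarianceSketch is a hypothetical module for illustration, and it assumes the mean is taken over every element (batch and feature axes), which may differ from how the library normalizes.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule InvarianceSketch do
  # Mean squared error over every element of the two embedding batches.
  # Assumption: the mean runs over both the batch and the feature axis.
  def invariance(z, z_prime) do
    z
    |> Nx.subtract(z_prime)
    |> Nx.pow(2)
    |> Nx.mean()
  end
end

z = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
z_prime = Nx.tensor([[1.0, 0.0], [3.0, 2.0]])

# Squared differences are [[0, 4], [0, 4]], so the mean is 2.0.
loss = InvarianceSketch.invariance(z, z_prime) |> Nx.to_number()
IO.inspect(loss)
# => 2.0
```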

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the VICReg model.

variance_loss(z, target \\ 1.0)

@spec variance_loss(Nx.Tensor.t(), float()) :: Nx.Tensor.t()

Variance term: hinge loss on standard deviation.

Prevents informational collapse by keeping the standard deviation of each dimension, measured across the batch, above a target threshold.
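
The hinge on the per-dimension standard deviation can be sketched in plain Nx. VarianceSketch is a hypothetical module for illustration; the eps value of 1.0e-4 and the use of the population variance (Nx.variance with its default ddof of 0) are assumptions, not the library's confirmed behavior.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule VarianceSketch do
  # mean over dimensions of max(0, target - std), with std taken across the batch.
  # Assumptions: eps = 1.0e-4 inside the square root, population variance.
  def variance(z, target \\ 1.0, eps \\ 1.0e-4) do
    std = z |> Nx.variance(axes: [0]) |> Nx.add(eps) |> Nx.sqrt()

    target
    |> Nx.subtract(std)
    |> Nx.max(0.0)
    |> Nx.mean()
  end
end

# A collapsed batch (identical rows) is penalized: the result is close to 0.99.
collapsed = Nx.tensor([[1.0, 5.0], [1.0, 5.0]])
collapsed_loss = collapsed |> VarianceSketch.variance() |> Nx.to_number()
IO.inspect(collapsed_loss)

# A batch whose dimensions each have std of about 2 incurs no penalty.
spread = Nx.tensor([[-2.0, 0.0], [2.0, 4.0]])
spread_loss = spread |> VarianceSketch.variance() |> Nx.to_number()
IO.inspect(spread_loss)
# => 0.0
```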

vicreg_loss(z, z_prime, opts \\ [])

@spec vicreg_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute the full VICReg loss.

Parameters

  • z - Embeddings from view 1: [batch, projection_dim]
  • z_prime - Embeddings from view 2: [batch, projection_dim]

Options

  • :lambda_inv - Invariance loss weight (default: 25.0)
  • :mu_var - Variance loss weight (default: 25.0)
  • :nu_cov - Covariance loss weight (default: 1.0)
  • :variance_target - Target standard deviation (default: 1.0)

Returns

Scalar loss tensor.
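
As a sanity check, the full loss can be compared against a manual weighted sum of the individual term functions documented above. This sketch assumes a Mix project that already depends on Edifice and Nx, and that vicreg_loss combines the terms exactly as in the formula in the Loss Terms section; if the implementation differs, the two values will not match.

```elixir
# Assumes a Mix project that already depends on Edifice and Nx.
alias Edifice.Contrastive.VICReg

# Toy projections for two views: batch of 4, projection_dim of 3.
z =
  Nx.tensor([
    [0.5, -1.0, 2.0],
    [1.5, 0.0, -0.5],
    [-1.0, 1.0, 0.5],
    [0.0, -0.5, -1.5]
  ])

z_prime = Nx.add(z, 0.1)

lambda_inv = 25.0
mu_var = 25.0
nu_cov = 1.0

# Individual terms, using the documented public functions.
inv_term = VICReg.invariance_loss(z, z_prime)
var_term = Nx.add(VICReg.variance_loss(z), VICReg.variance_loss(z_prime))
cov_term = Nx.add(VICReg.covariance_loss(z), VICReg.covariance_loss(z_prime))

# Manual weighted sum following the formula in the Loss Terms section.
manual =
  Nx.multiply(inv_term, lambda_inv)
  |> Nx.add(Nx.multiply(var_term, mu_var))
  |> Nx.add(Nx.multiply(cov_term, nu_cov))

combined =
  VICReg.vicreg_loss(z, z_prime,
    lambda_inv: lambda_inv,
    mu_var: mu_var,
    nu_cov: nu_cov
  )

IO.inspect({Nx.to_number(manual), Nx.to_number(combined)})
```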