VICReg - Variance-Invariance-Covariance Regularization.
Implements VICReg from "VICReg: Variance-Invariance-Covariance Regularization for Self-Supervised Learning" (Bardes et al., ICLR 2022). VICReg prevents representation collapse through three explicit regularization terms applied directly to the embedding vectors, without requiring negative pairs, asymmetric networks, or momentum encoders.
Key Innovations
- Explicit collapse prevention: Three distinct terms each prevent a different mode of collapse
- No architectural tricks: Symmetric architecture, no stop-gradient, no momentum encoder, no negative mining
- Interpretable loss: Each term has a clear geometric meaning
Loss Terms
- Variance (v): Keep the standard deviation of each embedding dimension, measured across the batch, above a threshold (prevents informational collapse, where all embeddings become identical)
- Invariance (i): MSE between embeddings of augmented views (ensures representations are view-invariant)
- Covariance (c): Decorrelate embedding dimensions (prevents dimensional collapse where all dimensions are correlated)
L = lambda * invariance(Z, Z')
    + mu * [variance(Z) + variance(Z')]
    + nu * [covariance(Z) + covariance(Z')]
Architecture
 Augmented View 1     Augmented View 2
        |                    |
        v                    v
  +------------+       +------------+
  |  Encoder   |       |  Encoder   |  (shared weights)
  +------------+       +------------+
        |                    |
        v                    v
  +------------+       +------------+
  | Projector  |       | Projector  |  (shared weights)
  +------------+       +------------+
        |                    |
        v                    v
        Z                    Z'
        |                    |
        +---> VICReg Loss <--+
Usage
model = VICReg.build(encoder_dim: 287, projection_dim: 256)
# Compute loss between two batches of projections
loss = VICReg.vicreg_loss(z, z_prime,
  lambda_inv: 25.0,
  mu_var: 25.0,
  nu_cov: 1.0
)
Summary
Functions
Build a VICReg model (encoder + projector).
Covariance term: decorrelate embedding dimensions.
Default encoder/projector hidden dimension
Default invariance loss coefficient
Default variance loss coefficient
Default covariance loss coefficient
Default projection head output dimension
Default variance target threshold
Invariance term: MSE between the two views.
Get the output size of the VICReg model.
Variance term: hinge loss on standard deviation.
Compute the full VICReg loss.
Types
@type build_opt() :: {:encoder_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:projection_dim, pos_integer()}
Options for build/1.
Functions
Build a VICReg model (encoder + projector).
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projector output dimension (default: 256)
- :hidden_size - Hidden dimension for encoder and projector (default: 512)
Returns
An Axon model mapping inputs to projection embeddings.
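The option handling described above can be illustrated with a small sketch. It assumes plain `Keyword` lookups with the documented defaults; `BuildOptsSketch` and `resolve_opts/1` are hypothetical names, not part of this module's API, and the real `build/1` additionally constructs the Axon encoder and projector.

```elixir
defmodule BuildOptsSketch do
  # Hypothetical helper: resolves build/1-style options using the defaults
  # listed in the Options section. Not the module's actual implementation.
  def resolve_opts(opts) do
    %{
      # :encoder_dim is required, so a missing value raises KeyError.
      encoder_dim: Keyword.fetch!(opts, :encoder_dim),
      hidden_size: Keyword.get(opts, :hidden_size, 512),
      projection_dim: Keyword.get(opts, :projection_dim, 256)
    }
  end
end

BuildOptsSketch.resolve_opts(encoder_dim: 287)
# => %{encoder_dim: 287, hidden_size: 512, projection_dim: 256}
```

Failing fast on a missing `:encoder_dim` mirrors its "required" status, while the other two fall back to their documented defaults.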
@spec covariance_loss(Nx.Tensor.t()) :: Nx.Tensor.t()
Covariance term: decorrelate embedding dimensions.
Prevents dimensional collapse by pushing the off-diagonal elements of the covariance matrix toward zero.
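One way to realize this term, following the definition in the VICReg paper, is to build the batch covariance matrix C(Z) with an (n - 1) normalizer and take c(Z) = (1/d) * sum over i != j of C(Z)[i][j]^2. The sketch below uses plain Elixir lists (rows are samples, columns are embedding dimensions) instead of `Nx` tensors so it runs without dependencies; `CovarianceSketch` is a hypothetical name, and it is not claimed to match this module's internals exactly.

```elixir
defmodule CovarianceSketch do
  # Hypothetical list-based version of the covariance term.
  # z is a list of n rows, each a list of d floats.
  def covariance_loss(z) do
    n = length(z)

    # Transpose to d columns (one per embedding dimension) and center each.
    centered =
      z
      |> Enum.zip()
      |> Enum.map(fn column ->
        values = Tuple.to_list(column)
        mean = Enum.sum(values) / n
        Enum.map(values, &(&1 - mean))
      end)

    d = length(centered)
    indexed = Enum.with_index(centered)

    # Sum of squared off-diagonal covariance entries C[i][j] with i != j.
    off_diagonal =
      for {ci, i} <- indexed, {cj, j} <- indexed, i != j, reduce: 0.0 do
        acc ->
          dot = Enum.zip(ci, cj) |> Enum.map(fn {a, b} -> a * b end) |> Enum.sum()
          cov = dot / (n - 1)
          acc + cov * cov
      end

    off_diagonal / d
  end
end

# Two perfectly correlated dimensions are penalized heavily:
CovarianceSketch.covariance_loss([[1.0, 2.0], [3.0, 6.0]])
# => 16.0
```

Orthogonal columns, such as `[[1.0, 1.0], [-1.0, 1.0], [1.0, -1.0], [-1.0, -1.0]]`, give 0.0, since the diagonal (each dimension's own variance) is left to the variance term.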
@spec default_lambda_inv() :: float()
Default invariance loss coefficient
@spec default_mu_var() :: float()
Default variance loss coefficient
@spec default_nu_cov() :: float()
Default covariance loss coefficient
@spec default_projection_dim() :: pos_integer()
Default projection head output dimension
@spec default_variance_target() :: float()
Default variance target threshold
@spec invariance_loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Invariance term: MSE between the two views.
Encourages representations of augmented views to be similar.
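A list-based sketch of the invariance term follows. It assumes a plain element-wise mean squared error averaged over every batch × dimension entry; the exact normalization this module uses is not stated here, and `InvarianceSketch` is a hypothetical name.

```elixir
defmodule InvarianceSketch do
  # Hypothetical list-based invariance term: mean squared error between
  # corresponding entries of the two views' embeddings.
  def invariance_loss(z, z_prime) do
    squared_errors =
      Enum.zip(z, z_prime)
      |> Enum.flat_map(fn {row, row_prime} ->
        Enum.zip(row, row_prime)
        |> Enum.map(fn {a, b} -> (a - b) * (a - b) end)
      end)

    Enum.sum(squared_errors) / length(squared_errors)
  end
end

InvarianceSketch.invariance_loss([[1.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [3.0, 4.0]])
# => 1.0
```

On its own this term is trivially minimized by mapping every input to the same vector, which is why it is paired with the variance and covariance terms.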
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the VICReg model.
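Since the model maps inputs to projection embeddings of shape [batch, projection_dim], a plausible reading is that the output size equals the projector's output dimension. The sketch below encodes that assumption with the documented default of 256; `OutputSizeSketch` is hypothetical, and the real function may compute this differently.

```elixir
defmodule OutputSizeSketch do
  # Assumption: the model's output width is the projector output dimension,
  # falling back to the documented :projection_dim default of 256.
  def output_size(opts) do
    Keyword.get(opts, :projection_dim, 256)
  end
end

OutputSizeSketch.output_size(encoder_dim: 287)
# => 256
```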
@spec variance_loss(Nx.Tensor.t(), float()) :: Nx.Tensor.t()
Variance term: hinge loss on standard deviation.
Prevents informational collapse by keeping the standard deviation of each embedding dimension, measured across the batch, above a target threshold.
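Following the paper's formulation, the hinge is the mean over dimensions of max(0, target - sqrt(Var + eps)), where a small eps guards against numerical instability when a dimension's variance approaches zero. This list-based sketch assumes an unbiased (n - 1) variance and eps = 1.0e-4; both are assumptions rather than confirmed details of this module, and `VarianceSketch` is a hypothetical name.

```elixir
defmodule VarianceSketch do
  # Assumed small constant added inside the square root for numerical stability.
  @eps 1.0e-4

  # Hypothetical list-based variance term; z is a list of n rows of d floats.
  def variance_loss(z, target \\ 1.0) do
    n = length(z)

    hinges =
      z
      |> Enum.zip()
      |> Enum.map(fn column ->
        values = Tuple.to_list(column)
        mean = Enum.sum(values) / n
        # Unbiased variance of this dimension across the batch.
        var = Enum.reduce(values, 0.0, fn x, acc -> acc + (x - mean) * (x - mean) end) / (n - 1)
        std = :math.sqrt(var + @eps)
        # Penalize only when the standard deviation falls below the target.
        max(0.0, target - std)
      end)

    Enum.sum(hinges) / length(hinges)
  end
end

# A collapsed batch (every row identical) is penalized close to the maximum:
VarianceSketch.variance_loss([[0.5, 0.5], [0.5, 0.5]])
# ≈ 0.99, because each dimension's std is only sqrt(1.0e-4) = 0.01
```

Using a hinge rather than a plain penalty means dimensions that already spread out beyond the target contribute nothing, so the term does not push variance up indefinitely.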
@spec vicreg_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute the full VICReg loss.
Parameters
- z - Embeddings from view 1: [batch, projection_dim]
- z_prime - Embeddings from view 2: [batch, projection_dim]
Options
- :lambda_inv - Invariance loss weight (default: 25.0)
- :mu_var - Variance loss weight (default: 25.0)
- :nu_cov - Covariance loss weight (default: 1.0)
- :variance_target - Target standard deviation (default: 1.0)
Returns
Scalar loss tensor.
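The weighted sum from the loss formula can be sketched end to end with plain lists. This is a self-contained illustration, not the module's Nx implementation: the per-term normalizations (element-wise MSE, unbiased variance with eps = 1.0e-4, squared off-diagonal covariance averaged over d) are assumptions, and `VICRegLossSketch` is a hypothetical name. The option names and defaults mirror the ones documented for `vicreg_loss`.

```elixir
defmodule VICRegLossSketch do
  # Assumed stability constant inside the standard-deviation square root.
  @eps 1.0e-4

  # Weighted VICReg loss over two views given as lists of rows.
  def vicreg_loss(z, z_prime, opts \\ []) do
    lambda = Keyword.get(opts, :lambda_inv, 25.0)
    mu = Keyword.get(opts, :mu_var, 25.0)
    nu = Keyword.get(opts, :nu_cov, 1.0)
    target = Keyword.get(opts, :variance_target, 1.0)

    lambda * invariance(z, z_prime) +
      mu * (variance(z, target) + variance(z_prime, target)) +
      nu * (covariance(z) + covariance(z_prime))
  end

  # Mean squared error over all entries of the two views.
  defp invariance(z, z_prime) do
    errors =
      for {row, row_prime} <- Enum.zip(z, z_prime), {a, b} <- Enum.zip(row, row_prime) do
        (a - b) * (a - b)
      end

    Enum.sum(errors) / length(errors)
  end

  # Mean hinge on the per-dimension standard deviation.
  defp variance(z, target) do
    hinges =
      for column <- centered_columns(z) do
        var = dot(column, column) / (length(z) - 1)
        max(0.0, target - :math.sqrt(var + @eps))
      end

    Enum.sum(hinges) / length(hinges)
  end

  # Squared off-diagonal covariance entries, averaged over d.
  defp covariance(z) do
    n = length(z)
    columns = Enum.with_index(centered_columns(z))

    total =
      for {ci, i} <- columns, {cj, j} <- columns, i != j, reduce: 0.0 do
        acc ->
          cov = dot(ci, cj) / (n - 1)
          acc + cov * cov
      end

    total / length(columns)
  end

  # Transpose rows into per-dimension columns and subtract each column mean.
  defp centered_columns(z) do
    n = length(z)

    for column <- Enum.zip(z) do
      values = Tuple.to_list(column)
      mean = Enum.sum(values) / n
      Enum.map(values, &(&1 - mean))
    end
  end

  defp dot(a, b), do: Enum.zip(a, b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
end

collapsed = [[0.0, 0.0], [0.0, 0.0]]
VICRegLossSketch.vicreg_loss(collapsed, collapsed)
# ≈ 49.5, almost entirely from the variance term (25.0 * (0.99 + 0.99))
```

With the default weights, a fully collapsed batch has zero invariance and covariance loss, so only the variance term registers the failure; this is the mode of collapse that term exists to penalize.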