Barlow Twins - Redundancy Reduction for Self-Supervised Learning.
Implements Barlow Twins from "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" (Zbontar et al., ICML 2021). Barlow Twins prevents representation collapse by pushing the cross-correlation matrix of two augmented views toward the identity matrix.
Key Innovation
The loss has two terms on the cross-correlation matrix C:
- Invariance: Diagonal elements should be 1 (views agree per dimension)
- Redundancy reduction: Off-diagonal elements should be 0 (different dimensions should be decorrelated)
L = SUM_i (1 - C_ii)^2 + lambda * SUM_{i!=j} C_ij^2
Architecture
 Augmented View 1        Augmented View 2
        |                       |
        v                       v
  +------------+          +------------+
  |  Encoder   |          |  Encoder   |  (shared weights)
  +------------+          +------------+
        |                       |
        v                       v
  +------------+          +------------+
  | Projector  |          | Projector  |  (shared weights)
  +------------+          +------------+
        |                       |
        v                       v
       Z_A     Cross-Corr      Z_B
        |                       |
        +-------> Loss <--------+
Usage
model = BarlowTwins.build(encoder_dim: 287, projection_dim: 256)
loss = BarlowTwins.barlow_twins_loss(z_a, z_b, lambda_param: 0.005)
References
Summary
Functions
Compute the Barlow Twins loss.
Build a Barlow Twins model (encoder + projector).
Default hidden dimension
Default redundancy reduction coefficient
Default projection dimension
Get the output size of the Barlow Twins model.
Types
@type build_opt() :: {:encoder_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:projection_dim, pos_integer()}
Options for build/1.
Functions
@spec barlow_twins_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute the Barlow Twins loss.
Parameters
- z_a - Embeddings from view A: [batch, projection_dim]
- z_b - Embeddings from view B: [batch, projection_dim]
Options
- :lambda_param - Off-diagonal penalty weight (default: 0.005)
Returns
Scalar loss tensor.
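The computation can be sketched with plain Nx as follows. This is a minimal illustration, not this module's implementation: it assumes each embedding dimension is standardized over the batch before the cross-correlation is taken, as the paper describes. The module name BarlowSketch, its helper functions, and the eps value are hypothetical.

```elixir
# Standalone script; assumes a recent Nx release is available from Hex.
Mix.install([{:nx, "~> 0.7"}])

defmodule BarlowSketch do
  # Hypothetical helper module for illustration only.

  # Standardize each embedding dimension across the batch:
  # zero mean and (approximately) unit standard deviation per column.
  def standardize(z, eps \\ 1.0e-5) do
    mean = Nx.mean(z, axes: [0], keep_axes: true)
    std = Nx.standard_deviation(z, axes: [0], keep_axes: true)
    Nx.divide(Nx.subtract(z, mean), Nx.add(std, eps))
  end

  # C = Z_A^T * Z_B / batch, with shape [projection_dim, projection_dim].
  def cross_correlation(z_a, z_b) do
    batch = Nx.axis_size(z_a, 0)

    z_a
    |> standardize()
    |> Nx.transpose()
    |> Nx.dot(standardize(z_b))
    |> Nx.divide(batch)
  end

  # L = SUM_i (1 - C_ii)^2 + lambda * SUM_{i!=j} C_ij^2
  def loss(z_a, z_b, lambda \\ 0.005) do
    c = cross_correlation(z_a, z_b)
    diag = Nx.take_diagonal(c)

    # Invariance term: diagonal entries pulled toward 1.
    on_diag = diag |> Nx.subtract(1.0) |> Nx.pow(2) |> Nx.sum()

    # Redundancy term: sum of all squared entries minus the squared diagonal.
    off_diag = Nx.subtract(Nx.sum(Nx.pow(c, 2)), Nx.sum(Nx.pow(diag, 2)))

    Nx.add(on_diag, Nx.multiply(lambda, off_diag))
  end
end

# Two identical views whose dimensions are already uncorrelated:
# C is close to the identity, so the loss is close to 0.
z = Nx.tensor([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
IO.inspect(Nx.to_number(BarlowSketch.loss(z, z)))
```

Computing the off-diagonal sum as "all squared entries minus the squared diagonal" avoids building an explicit mask. Masking with an identity matrix would work equally well.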
Build a Barlow Twins model (encoder + projector).
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projector output dimension (default: 256)
- :hidden_size - Hidden dimension (default: 512)
Returns
An Axon model mapping inputs to projection embeddings.
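To show how the three options could shape a model, here is a rough Axon sketch. The layer layout (one dense encoder layer followed by a two-layer projector), the module name BarlowModelSketch, and the input name "features" are assumptions for illustration. The layers that build/1 actually creates are not documented here and may differ.

```elixir
# Standalone script; assumes Axon (which pulls in Nx) is available from Hex.
Mix.install([{:axon, "~> 0.6"}])

defmodule BarlowModelSketch do
  # Hypothetical stand-in for an encoder + projector builder.
  def build(opts) do
    # :encoder_dim is required; the other two fall back to the documented defaults.
    encoder_dim = Keyword.fetch!(opts, :encoder_dim)
    hidden_size = Keyword.get(opts, :hidden_size, 512)
    projection_dim = Keyword.get(opts, :projection_dim, 256)

    Axon.input("features", shape: {nil, encoder_dim})
    # Encoder (assumed: a single dense layer).
    |> Axon.dense(hidden_size, activation: :relu, name: "encoder")
    # Projector (assumed: hidden layer plus linear output).
    |> Axon.dense(hidden_size, activation: :relu, name: "projector_hidden")
    |> Axon.dense(projection_dim, name: "projector_out")
  end
end

model = BarlowModelSketch.build(encoder_dim: 287)

# Inspect the output shape for a batch of 4 inputs without running the model.
IO.inspect(Axon.get_output_shape(model, Nx.template({4, 287}, :f32)))
```

Both augmented views would pass through this one model, which is what "shared weights" means in the architecture diagram. The two resulting embedding batches are then the z_a and z_b arguments of the loss.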
@spec default_lambda() :: float()
Default redundancy reduction coefficient
@spec default_projection_dim() :: pos_integer()
Default projection dimension
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the Barlow Twins model.