Edifice.Contrastive.BarlowTwins (Edifice v0.2.0)

Barlow Twins - Redundancy Reduction for Self-Supervised Learning.

Implements Barlow Twins from "Barlow Twins: Self-Supervised Learning via Redundancy Reduction" (Zbontar et al., ICML 2021). Barlow Twins prevents representation collapse by pushing the cross-correlation matrix between the embeddings of two augmented views toward the identity matrix.

Key Innovation

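The loss has two terms on the cross-correlation matrix C. C is computed along the batch dimension between the two views' embeddings: C_ij is the correlation between embedding dimension i of view A and dimension j of view B.

  1. Invariance: diagonal elements should be 1 (the two views agree in each dimension).
  2. Redundancy reduction: off-diagonal elements should be 0 (different dimensions should be decorrelated).

L = SUM_i (1 - C_ii)^2 + lambda * SUM_{i!=j} C_ij^2

Here lambda (the :lambda_param option of barlow_twins_loss/3) weights the redundancy-reduction term against the invariance term.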
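As a worked check of the formula, the sketch below evaluates both terms on small, hand-written 2 x 2 matrices C. It is plain Elixir over nested lists rather than Edifice's Nx code. The module and function names (BarlowLossSketch.loss_from_corr/2) are hypothetical.

```elixir
defmodule BarlowLossSketch do
  # Hypothetical helper, not part of Edifice: evaluates
  #   L = SUM_i (1 - C_ii)^2 + lambda * SUM_{i!=j} C_ij^2
  # for a cross-correlation matrix C given as a list of rows.
  def loss_from_corr(c, lambda) do
    for {row, i} <- Enum.with_index(c),
        {c_ij, j} <- Enum.with_index(row),
        reduce: 0.0 do
      acc ->
        if i == j do
          # Invariance term: diagonal entries are pulled toward 1.
          acc + (1.0 - c_ij) * (1.0 - c_ij)
        else
          # Redundancy-reduction term: off-diagonal entries are pulled toward 0.
          acc + lambda * c_ij * c_ij
        end
    end
  end
end

# Perfect agreement and no redundancy: C = I gives zero loss.
BarlowLossSketch.loss_from_corr([[1.0, 0.0], [0.0, 1.0]], 0.005)
# => 0.0

# (1 - 0.9)^2 + 0.005 * (0.2^2 + 0.1^2) = 0.01 + 0.00025
BarlowLossSketch.loss_from_corr([[0.9, 0.2], [0.1, 1.0]], 0.005)
# ≈ 0.01025
```

The small default lambda keeps the many off-diagonal entries from dominating the few diagonal ones.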

Architecture

Augmented View 1         Augmented View 2
      |                         |
      v                         v
+------------+           +------------+
|  Encoder   |           |  Encoder   |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
+------------+           +------------+
| Projector  |           | Projector  |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
     Z_A                       Z_B
      |                         |
      +-----> Cross-Corr C <----+
                   |
                   v
                 Loss
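The Cross-Corr step in the diagram can be sketched in plain Elixir as follows. It follows the paper's batch-wise definition: each embedding dimension is standardized over the batch (zero mean, unit variance), then C = Z_A^T Z_B / N for batch size N. Several details are assumptions for illustration: the module name, the list-of-rows representation, and the eps stabilizer (1.0e-5). Edifice's own Nx code may normalize differently.

```elixir
defmodule CrossCorrSketch do
  # Hypothetical helper, not Edifice's API. Embeddings are lists of rows with
  # shape [batch, dim]; the result is the dim x dim cross-correlation matrix.
  def cross_correlation(z_a, z_b, eps \\ 1.0e-5) do
    n = length(z_a)
    cols_a = standardize_columns(z_a, eps)
    cols_b = standardize_columns(z_b, eps)

    # C_ij = (1/N) * SUM_b a_bi * b_bj over the standardized embeddings.
    for a <- cols_a do
      for b <- cols_b, do: dot(a, b) / n
    end
  end

  # Transpose rows into columns, then give each column zero mean and unit
  # variance over the batch (eps guards against constant columns).
  defp standardize_columns(z, eps) do
    n = length(z)

    z
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
    |> Enum.map(fn col ->
      mean = Enum.sum(col) / n
      var = Enum.reduce(col, 0.0, fn x, acc -> acc + (x - mean) * (x - mean) end) / n
      std = :math.sqrt(var + eps)
      Enum.map(col, fn x -> (x - mean) / std end)
    end)
  end

  defp dot(xs, ys) do
    xs |> Enum.zip(ys) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
  end
end

# Identical views whose two features are uncorrelated: C is approximately I
# (diagonal entries are about 0.99999 because of eps).
z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
CrossCorrSketch.cross_correlation(z, z)
```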

Usage

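alias Edifice.Contrastive.BarlowTwins

model = BarlowTwins.build(encoder_dim: 287, projection_dim: 256)

# z_a, z_b: projector outputs for the two augmented views, each [batch, projection_dim]
loss = BarlowTwins.barlow_twins_loss(z_a, z_b, lambda_param: 0.005)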

References

  • Zbontar et al., "Barlow Twins: Self-Supervised Learning via Redundancy Reduction", ICML 2021.

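Summary

Types

  • build_opt() - Options for build/1.

Functions

  • barlow_twins_loss(z_a, z_b, opts \\ []) - Compute the Barlow Twins loss.
  • build(opts \\ []) - Build a Barlow Twins model (encoder + projector).
  • default_hidden_size() - Default hidden dimension.
  • default_lambda() - Default redundancy reduction coefficient.
  • default_projection_dim() - Default projection dimension.
  • output_size(opts \\ []) - Get the output size of the Barlow Twins model.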

Types

build_opt()

@type build_opt() ::
  {:encoder_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:projection_dim, pos_integer()}

Options for build/1.

Functions

barlow_twins_loss(z_a, z_b, opts \\ [])

@spec barlow_twins_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute the Barlow Twins loss.

Parameters

  • z_a - Embeddings from view A: [batch, projection_dim]
  • z_b - Embeddings from view B, same shape as z_a. Row b of z_b must come from the same input example as row b of z_a.

Options

  • :lambda_param - Off-diagonal penalty weight (default: 0.005)

Returns

Scalar loss tensor.
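For reference, the sketch below combines three steps for list-based z_a and z_b of shape [batch, projection_dim]: batch-wise standardization, the cross-correlation, and the two-term loss. It returns a single float in place of the scalar tensor. This is an illustrative plain-Elixir version, not Edifice's Nx implementation. The module name and the :eps option are hypothetical, while :lambda_param mirrors the option above.

```elixir
defmodule BarlowTwinsSketch do
  # Hypothetical pure-Elixir reference, not Edifice's implementation.
  # z_a, z_b: lists of rows with shape [batch, projection_dim].
  def loss(z_a, z_b, opts \\ []) do
    lambda = Keyword.get(opts, :lambda_param, 0.005)
    # Assumed stabilizer for near-constant dimensions; not a documented option.
    eps = Keyword.get(opts, :eps, 1.0e-5)
    n = length(z_a)
    cols_a = standardize_columns(z_a, eps)
    cols_b = standardize_columns(z_b, eps)

    for {a, i} <- Enum.with_index(cols_a),
        {b, j} <- Enum.with_index(cols_b),
        reduce: 0.0 do
      acc ->
        # Entry C_ij of the cross-correlation matrix.
        c_ij = dot(a, b) / n

        if i == j,
          do: acc + (1.0 - c_ij) * (1.0 - c_ij),
          else: acc + lambda * c_ij * c_ij
    end
  end

  # Per-dimension zero mean / unit variance over the batch; returns columns.
  defp standardize_columns(z, eps) do
    n = length(z)

    z
    |> Enum.zip()
    |> Enum.map(&Tuple.to_list/1)
    |> Enum.map(fn col ->
      mean = Enum.sum(col) / n
      var = Enum.reduce(col, 0.0, fn x, acc -> acc + (x - mean) * (x - mean) end) / n
      std = :math.sqrt(var + eps)
      Enum.map(col, fn x -> (x - mean) / std end)
    end)
  end

  defp dot(xs, ys) do
    xs |> Enum.zip(ys) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
  end
end

# Identical views with decorrelated features: the loss is close to 0.
z = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
BarlowTwinsSketch.loss(z, z)

# Redundant features (both dimensions carry the same signal): the
# off-diagonal penalty gives a loss close to 2 * 0.005 = 0.01.
r = [[1.0, 1.0], [-1.0, -1.0], [1.0, 1.0], [-1.0, -1.0]]
BarlowTwinsSketch.loss(r, r)
```

The second call shows why the redundancy term matters. Both views agree perfectly, so the invariance term alone is nearly zero, yet the representation wastes a dimension.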

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Barlow Twins model (encoder + projector).

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projector output dimension (default: 256)
  • :hidden_size - Hidden dimension (default: 512)

Returns

An Axon model mapping inputs to projection embeddings.

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_lambda()

@spec default_lambda() :: float()

Default redundancy reduction coefficient

default_projection_dim()

@spec default_projection_dim() :: pos_integer()

Default projection dimension

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the Barlow Twins model, i.e. the width of the projection embeddings it returns.