Edifice.Contrastive.SimCLR (Edifice v0.2.0)


SimCLR - Simple Contrastive Learning of Representations.

Implements SimCLR from "A Simple Framework for Contrastive Learning of Visual Representations" (Chen et al., ICML 2020). SimCLR learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss (NT-Xent).

Key Components

  • Augmentation: Two random augmentations of each example
  • Encoder: Shared backbone that extracts representations
  • Projection Head: MLP that maps representations to contrastive space
  • NT-Xent Loss: Normalized temperature-scaled cross-entropy
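This page does not say how the two augmented views are produced. build/1 takes plain feature vectors of size :encoder_dim, so image-style crops and color jitter do not apply directly. The sketch below assumes features are lists of floats and uses two common choices for non-image data, random feature masking and additive Gaussian noise. The module name AugmentSketch and both augmentations are hypothetical and are not part of Edifice.

```elixir
defmodule AugmentSketch do
  @moduledoc """
  Hypothetical augmentations for plain feature vectors (lists of floats).
  Not part of Edifice: only an illustration of producing two random views.
  """

  # Add zero-mean Gaussian noise with standard deviation `sigma` to each feature.
  def gaussian_noise(features, sigma) do
    Enum.map(features, fn x -> x + sigma * :rand.normal() end)
  end

  # Independently replace each feature with 0.0 with probability `p`.
  def random_mask(features, p) do
    Enum.map(features, fn x -> if :rand.uniform() < p, do: 0.0, else: x end)
  end

  # Two independent mask-then-noise draws give the two views of one example.
  def two_views(features, sigma \\ 0.1, p \\ 0.2) do
    augment = fn -> features |> random_mask(p) |> gaussian_noise(sigma) end
    {augment.(), augment.()}
  end
end
```

Because each call to `augment` draws fresh randomness, the two views differ even though they come from the same example. That difference is what gives the contrastive loss something to be invariant to.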

Architecture

Augmented View 1         Augmented View 2
      |                         |
      v                         v
+------------+           +------------+
|  Encoder   |           |  Encoder   |  (shared weights)
+------------+           +------------+
      |                         |
      v                         v
+------------+           +------------+
| Projection |           | Projection |  (shared weights)
|    Head    |           |    Head    |
+------------+           +------------+
      |                         |
      v                         v
     z_i        NT-Xent        z_j
      |                         |
      +-------> Loss <----------+
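The shared weights in the diagram mean that both views pass through the same function with the same parameters. The plain-Elixir sketch below makes this concrete by applying one parameter map to both views. The module SharedTowerSketch is hypothetical, and its single dense+ReLU encoder and single dense projection head are illustrative assumptions. The actual Edifice layer sizes and activations are not specified on this page.

```elixir
defmodule SharedTowerSketch do
  @moduledoc """
  Hypothetical illustration of SimCLR weight sharing using plain lists.
  Not Edifice's Axon model; layer structure is an assumption.
  """

  # y = W x + b, with W given as a list of rows.
  def dense(x, %{w: w, b: b}) do
    w
    |> Enum.map(fn row -> dot(row, x) end)
    |> Enum.zip_with(b, fn s, bias -> s + bias end)
  end

  def relu(x), do: Enum.map(x, &max(&1, 0.0))

  # Encoder produces the representation h; the projection head maps h to z.
  def forward(x, params) do
    h = x |> dense(params.encoder) |> relu()
    dense(h, params.projection)
  end

  # Both views use the *same* params map, i.e. shared weights.
  def embed_pair(view1, view2, params) do
    {forward(view1, params), forward(view2, params)}
  end

  defp dot(u, v), do: Enum.zip_with(u, v, fn a, c -> a * c end) |> Enum.sum()
end
```

Passing a single `params` map to both branches means one gradient update changes both towers at once. This is why the diagram shows two boxes that are really one network.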

Usage

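alias Edifice.Contrastive.SimCLR

model = SimCLR.build(encoder_dim: 287, projection_dim: 128)

# Compute NT-Xent loss between projections.
# z_i and z_j are the projections of two augmented views of the same batch,
# each of shape [batch, projection_dim].
loss = SimCLR.nt_xent_loss(z_i, z_j, temperature: 0.5)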

References

  • Chen et al., "A Simple Framework for Contrastive Learning of Visual Representations", ICML 2020.

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a SimCLR model (encoder + projection head).

default_hidden_size()

Default hidden dimension for the encoder and projection head.

default_projection_dim()

Default projection head output dimension.

default_temperature()

Default temperature for the NT-Xent loss.

nt_xent_loss(z_i, z_j, opts \\ [])

Compute the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss.

output_size(opts \\ [])

Get the output size of the SimCLR model.

Types

build_opt()

@type build_opt() ::
  {:encoder_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:projection_dim, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a SimCLR model (encoder + projection head).

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projection head output dimension (default: 128)
  • :hidden_size - Hidden dimension (default: 256)

Returns

An Axon model mapping inputs to projection embeddings.
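The sketch below shows how the documented options and defaults combine. :encoder_dim must be supplied, and the other two options fall back to 128 and 256. SimCLROptsSketch is a hypothetical module that mirrors the option list above. It is not Edifice's source, and the real build/1 may validate its options differently.

```elixir
defmodule SimCLROptsSketch do
  @moduledoc """
  Hypothetical resolution of build/1 options using the documented defaults.
  Illustrative only; not Edifice's implementation.
  """

  @default_projection_dim 128
  @default_hidden_size 256

  def resolve(opts) do
    %{
      # Required: Keyword.fetch!/2 raises KeyError when the key is missing.
      encoder_dim: Keyword.fetch!(opts, :encoder_dim),
      projection_dim: Keyword.get(opts, :projection_dim, @default_projection_dim),
      hidden_size: Keyword.get(opts, :hidden_size, @default_hidden_size)
    }
  end
end
```

For example, `SimCLROptsSketch.resolve(encoder_dim: 287)` yields `%{encoder_dim: 287, projection_dim: 128, hidden_size: 256}`.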

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension for the encoder and projection head.

default_projection_dim()

@spec default_projection_dim() :: pos_integer()

Default projection head output dimension.

default_temperature()

@spec default_temperature() :: float()

Default temperature for the NT-Xent loss.
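Inside NT-Xent, the temperature divides every cosine similarity before the softmax. Smaller values therefore sharpen the distribution, which tends to put more weight on the hardest negatives. The hypothetical TemperatureSketch module below is plain Elixir, not Edifice code, and shows the effect on a toy set of similarities.

```elixir
defmodule TemperatureSketch do
  @moduledoc "Hypothetical illustration of temperature scaling; not Edifice code."

  # Softmax over similarities divided by temperature `tau`.
  # The maximum is subtracted before exponentiating for numerical stability.
  def softmax(similarities, tau) do
    scaled = Enum.map(similarities, &(&1 / tau))
    m = Enum.max(scaled)
    exps = Enum.map(scaled, &:math.exp(&1 - m))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end
```

With similarities `[0.9, 0.5, 0.1]`, the top entry gets about 0.61 of the probability mass at `tau = 0.5` and about 0.98 at `tau = 0.1`.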

nt_xent_loss(z_i, z_j, opts \\ [])

@spec nt_xent_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss.

Given embeddings from two views of the same batch, treats (z_i[k], z_j[k]) as positive pairs and all other combinations as negatives.
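For reference, the NT-Xent objective from Chen et al. (2020) is written below in terms of this function's arguments. The two view batches are stacked into 2N rows, so each anchor has exactly one positive and 2N - 2 negatives. This page only states the positive/negative pairing, so it is an assumption that nt_xent_loss/3 averages over all 2N anchors exactly this way.

```latex
\begin{aligned}
e_k &= z_i[k], \quad e_{N+k} = z_j[k], \qquad k = 1, \dots, N \\
p(a) &= \begin{cases} a + N & a \le N \\ a - N & a > N \end{cases} \\
\operatorname{sim}(u, v) &= \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert} \\
\ell_a &= -\log \frac{\exp\big(\operatorname{sim}(e_a, e_{p(a)}) / \tau\big)}
  {\sum_{b=1,\, b \neq a}^{2N} \exp\big(\operatorname{sim}(e_a, e_b) / \tau\big)} \\
\mathcal{L} &= \frac{1}{2N} \sum_{a=1}^{2N} \ell_a
\end{aligned}
```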

Parameters

  • z_i - Embeddings from view 1: [batch, projection_dim]
  • z_j - Embeddings from view 2: [batch, projection_dim]

Options

  • :temperature - Temperature scaling (default: 0.5)

Returns

Scalar loss tensor.
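The following is a minimal plain-Elixir sketch of NT-Xent using the 2N-anchor formulation from the SimCLR paper. It operates on lists of rows rather than Nx tensors, which keeps every step visible. NTXentSketch is a hypothetical module, not Edifice's implementation, and averaging over all 2N anchors is an assumption about how the real function reduces the loss.

```elixir
defmodule NTXentSketch do
  @moduledoc """
  Hypothetical plain-Elixir NT-Xent on lists of rows (2N-anchor form from
  Chen et al., 2020). Not Edifice's Nx implementation.
  """

  # z_i, z_j: lists of N rows each, where row k of both lists is the same example.
  def loss(z_i, z_j, temperature \\ 0.5) do
    n = length(z_i)
    # Stack the views: indices 0..n-1 are view 1, n..2n-1 are view 2.
    indexed = (z_i ++ z_j) |> Enum.map(&normalize/1) |> Enum.with_index()

    indexed
    |> Enum.map(fn {ea, a} ->
      # The positive for anchor a is the same example in the other view.
      p = if a < n, do: a + n, else: a - n

      # Temperature-scaled cosine similarity to every other row (self excluded).
      logits =
        for {eb, b} <- indexed, b != a, into: %{} do
          {b, dot(ea, eb) / temperature}
        end

      # -log of the positive's softmax probability, via a stable log-sum-exp.
      log_sum_exp(Map.values(logits)) - Map.fetch!(logits, p)
    end)
    |> Enum.sum()
    |> Kernel./(2 * n)
  end

  defp dot(u, v), do: Enum.zip_with(u, v, fn a, b -> a * b end) |> Enum.sum()

  # L2-normalise a row; the tiny floor avoids dividing by zero for all-zero rows.
  defp normalize(v) do
    norm = max(:math.sqrt(dot(v, v)), 1.0e-12)
    Enum.map(v, &(&1 / norm))
  end

  defp log_sum_exp(xs) do
    m = Enum.max(xs)
    m + :math.log(Enum.sum(Enum.map(xs, &:math.exp(&1 - m))))
  end
end
```

With a batch of one example, the loss is 0 because each anchor's only candidate is its positive. In a toy two-example batch, pairing each row with the matching row in the other view gives a lower loss than pairing it with the mismatched row.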

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the SimCLR model.