SimCLR - Simple Contrastive Learning of Representations.
Implements SimCLR from "A Simple Framework for Contrastive Learning of Visual Representations" (Chen et al., ICML 2020). SimCLR learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss (NT-Xent).
Key Components
- Augmentation: Two random augmentations of each example
- Encoder: Shared backbone that extracts representations
- Projection Head: MLP that maps representations to contrastive space
- NT-Xent Loss: Normalized temperature-scaled cross-entropy
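This page does not say which augmentations produce the two views. The original paper uses image augmentations: random cropping with resize, color distortion, and Gaussian blur. For the flat feature vectors implied by :encoder_dim, a minimal stand-in could add Gaussian noise and randomly zero some features. The module ViewAugment and its parameters noise_std and drop_prob are hypothetical and are not part of this library.

```elixir
defmodule ViewAugment do
  # Hypothetical two-view augmentation for plain feature vectors (lists of floats).
  # Not part of the SimCLR module; the noise + masking choice is an assumption.

  # Augment one vector: each feature is zeroed with probability drop_prob,
  # otherwise perturbed with Gaussian noise of standard deviation noise_std.
  def augment(features, noise_std \\ 0.1, drop_prob \\ 0.1) do
    Enum.map(features, fn x ->
      if :rand.uniform() < drop_prob do
        0.0
      else
        x + noise_std * :rand.normal()
      end
    end)
  end

  # Two independent augmentations of the same example form a positive pair.
  def two_views(features, noise_std \\ 0.1, drop_prob \\ 0.1) do
    {augment(features, noise_std, drop_prob), augment(features, noise_std, drop_prob)}
  end
end
```

For example, `{view_a, view_b} = ViewAugment.two_views([0.3, 1.2, -0.7])` returns two differently perturbed copies of the same example. These two copies are what the encoder sees as view 1 and view 2.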
Architecture
    Augmented View 1          Augmented View 2
            |                         |
            v                         v
     +------------+            +------------+
     |  Encoder   |            |  Encoder   |   (shared weights)
     +------------+            +------------+
            |                         |
            v                         v
     +------------+            +------------+
     | Projection |            | Projection |   (shared weights)
     |    Head    |            |    Head    |
     +------------+            +------------+
            |                         |
            v                         v
           z_i        NT-Xent        z_j
            |                         |
            +--------> Loss <---------+
Usage
model = SimCLR.build(encoder_dim: 287, projection_dim: 128)
# Compute the NT-Xent loss between the projections z_i and z_j of the two views
loss = SimCLR.nt_xent_loss(z_i, z_j, temperature: 0.5)
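In the usage above, z_i and z_j are the projections of the two augmented views. "Shared weights" in the architecture diagram means that both views pass through a single parameter set. The plain-list sketch below makes this explicit. Several parts of it are assumptions and do not describe this module's Axon graph: the module SimCLRSketch, the params keys, the one-layer encoder, and the dense -> ReLU -> dense projection head (the head shape used in the SimCLR paper).

```elixir
defmodule SimCLRSketch do
  # Hypothetical pure-Elixir illustration of weight sharing.

  # Dense layer: weights is a list of rows, one row per output unit.
  def dense(x, weights, bias) do
    weights
    |> Enum.zip(bias)
    |> Enum.map(fn {row, b} ->
      row |> Enum.zip(x) |> Enum.reduce(b, fn {w, xi}, acc -> acc + w * xi end)
    end)
  end

  def relu(x), do: Enum.map(x, &max(&1, 0.0))

  # Encoder: a single dense + ReLU layer (assumed depth).
  def encode(x, %{enc_w: w, enc_b: b}), do: x |> dense(w, b) |> relu()

  # Projection head: dense -> ReLU -> dense.
  def project(h, %{p1_w: w1, p1_b: b1, p2_w: w2, p2_b: b2}) do
    h |> dense(w1, b1) |> relu() |> dense(w2, b2)
  end

  # Both views use the same params map: this is the weight sharing in the diagram.
  def forward_pair(view1, view2, params) do
    {view1 |> encode(params) |> project(params), view2 |> encode(params) |> project(params)}
  end
end
```

If the same input is passed as both views, forward_pair/3 returns identical projections, because nothing distinguishes the two branches except their inputs.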
Summary
Functions
Build a SimCLR model (encoder + projection head).
Default hidden dimension for encoder and projection head
Default projection head output dimension
Default temperature for NT-Xent loss
Compute the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss.
Get the output size of the SimCLR model.
Types
@type build_opt() :: {:encoder_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:projection_dim, pos_integer()}
Options for build/1.
Functions
Build a SimCLR model (encoder + projection head).
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projection head output dimension (default: 128)
- :hidden_size - Hidden dimension (default: 256)
Returns
An Axon model mapping inputs to projection embeddings.
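build/1 takes a keyword list. The sketch below shows one way the documented options and their defaults might be resolved before the model is assembled. SimCLROpts.resolve/1 is a hypothetical helper and not part of the library's API.

```elixir
defmodule SimCLROpts do
  # Defaults as documented in the options list for build/1.
  @default_hidden_size 256
  @default_projection_dim 128

  # Hypothetical helper: :encoder_dim is required, so Keyword.fetch!/2 raises
  # KeyError when it is missing; the other options fall back to defaults.
  def resolve(opts) do
    %{
      encoder_dim: Keyword.fetch!(opts, :encoder_dim),
      hidden_size: Keyword.get(opts, :hidden_size, @default_hidden_size),
      projection_dim: Keyword.get(opts, :projection_dim, @default_projection_dim)
    }
  end
end
```

For example, `SimCLROpts.resolve(encoder_dim: 287)` returns `%{encoder_dim: 287, hidden_size: 256, projection_dim: 128}`.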
@spec default_projection_dim() :: pos_integer()
Default projection head output dimension
@spec default_temperature() :: float()
Default temperature for NT-Xent loss
@spec nt_xent_loss(Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute the NT-Xent (Normalized Temperature-scaled Cross-Entropy) loss.
Given embeddings from two views of the same batch, treats (z_i[k], z_j[k]) as positive pairs and all other combinations as negatives.
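The formula below is the NT-Xent loss as defined by Chen et al. (2020), on the assumption that nt_xent_loss/3 follows it. N is the batch size, and the 2N embeddings are stacked as a_k = z_i[k] and a_{N+k} = z_j[k] for k = 1, ..., N. The function sim is cosine similarity and τ is the temperature. The symbols a and p(m) are introduced here for illustration only.

```latex
\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert},
\qquad
p(m) = \begin{cases} m + N, & m \le N \\ m - N, & m > N \end{cases}

\ell_m = -\log \frac{\exp\big(\mathrm{sim}(a_m, a_{p(m)}) / \tau\big)}
                    {\sum_{n = 1,\, n \ne m}^{2N} \exp\big(\mathrm{sim}(a_m, a_n) / \tau\big)},
\qquad
\mathcal{L} = \frac{1}{2N} \sum_{m = 1}^{2N} \ell_m
```

Each anchor a_m therefore has one positive, a_{p(m)}, and 2N - 2 negatives, all of which appear in the denominator.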
Parameters
- z_i - Embeddings from view 1: [batch, projection_dim]
- z_j - Embeddings from view 2: [batch, projection_dim]
Options
- :temperature - Temperature that divides the pairwise similarities (default: 0.5)
Returns
Scalar loss tensor.
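As a reference point, the following pure-Elixir sketch computes the same quantity on lists of floats. It assumes the standard NT-Xent formulation: L2-normalized embeddings, cosine similarities divided by the temperature, self-similarity excluded, and a mean taken over all 2N anchors. NTXent.loss/3 is hypothetical. The documented nt_xent_loss/3 works on Nx tensors and may differ in details such as the reduction.

```elixir
defmodule NTXent do
  # Hypothetical list-based reference implementation of NT-Xent.

  defp dot(a, b), do: a |> Enum.zip(b) |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)

  # L2-normalize so that dot products become cosine similarities.
  defp normalize(v) do
    norm = :math.sqrt(dot(v, v))
    Enum.map(v, &(&1 / norm))
  end

  # z_i, z_j: lists of N embeddings; row k of z_i and row k of z_j are positives.
  def loss(z_i, z_j, temperature \\ 0.5) do
    n = length(z_i)
    two_n = 2 * n
    # Stack both views; with 0-based indices, anchor m's positive is (m + N) mod 2N.
    all = Enum.map(z_i ++ z_j, &normalize/1) |> List.to_tuple()

    total =
      Enum.reduce(0..(two_n - 1), 0.0, fn m, acc ->
        a = elem(all, m)
        pos = rem(m + n, two_n)

        # Similarity / temperature against every other embedding (self excluded).
        logits =
          for k <- 0..(two_n - 1), k != m, do: {k, dot(a, elem(all, k)) / temperature}

        # Numerically stable log-sum-exp over the positive and all negatives.
        max_logit = logits |> Enum.map(&elem(&1, 1)) |> Enum.max()
        sum_exp = Enum.reduce(logits, 0.0, fn {_k, l}, s -> s + :math.exp(l - max_logit) end)
        lse = max_logit + :math.log(sum_exp)

        {_pos, pos_logit} = List.keyfind(logits, pos, 0)
        acc + (lse - pos_logit)
      end)

    total / two_n
  end
end
```

Take z = [[1.0, 0.0], [0.0, 1.0]]. Each anchor has similarity 1 to its positive and 0 to its two negatives, so NTXent.loss(z, z) evaluates to log(2 + e^2) - 2, about 0.2395. Lowering the temperature reduces this value further because the positive logit then dominates the denominator.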
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the SimCLR model.