Edifice.Contrastive.BYOL (Edifice v0.2.0)


BYOL - Bootstrap Your Own Latent.

Implements BYOL from "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" (Grill et al., NeurIPS 2020). BYOL learns representations without negative pairs by using two networks: an online network, trained by gradient-based optimization, and a target network whose weights are an exponential moving average (EMA) of the online network's weights.

Key Innovations

  • No negative pairs needed: Avoids representational collapse through an asymmetric design (a predictor on the online branch only, plus a slowly updated EMA target)
  • Online/target architecture: Target network provides stable regression targets
  • Predictor head: The online network has an extra predictor that the target lacks
  • EMA update: Target parameters are a slow-moving average of online parameters

Architecture

Augmented View 1              Augmented View 2
      |                             |
      v                             v
+============+               +============+
|  Online    |               |   Target   |
|  Encoder   |               |   Encoder  |  (EMA of online)
+============+               +============+
      |                             |
      v                             v
+============+               +============+
|  Online    |               |   Target   |
| Projector  |               |  Projector |  (EMA of online)
+============+               +============+
      |                             |
      v                             |
+============+                      |
|  Predictor |                      |
| (online    |                      |
|    only)   |                      |
+============+                      |
      |                             |
      v                             v
     p_i          MSE Loss         z_j
      |                             |
      +----------->.<---------------+
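The diagram shows one direction of the loss. In the paper the loss is symmetrized: each view is also passed through the opposite branch and the two terms are added. The sketch below shows only that data flow. ByolFlowSketch is a hypothetical module, and online_fn, target_fn and loss_fn are hypothetical stand-ins (toy functions on lists here), not functions from this module. In real training the target outputs are treated as constants, so no gradient flows into the target branch.

```elixir
defmodule ByolFlowSketch do
  @moduledoc """
  Hypothetical sketch of BYOL's symmetrized loss data flow (not Edifice code).

    * online_fn - online branch: encoder -> projector -> predictor
    * target_fn - target branch: encoder -> projector (EMA weights, no predictor)
    * loss_fn   - distance between an online prediction and a target projection
  """

  def symmetric_loss(online_fn, target_fn, loss_fn, view1, view2) do
    # View 1 through the online branch, view 2 through the target branch
    l1 = loss_fn.(online_fn.(view1), target_fn.(view2))
    # Swap the roles of the two views and add the second term
    l2 = loss_fn.(online_fn.(view2), target_fn.(view1))
    l1 + l2
  end
end

# Toy stand-ins: identity "networks" and squared Euclidean distance
sq_dist = fn a, b ->
  Enum.zip_with(a, b, fn x, y -> (x - y) * (x - y) end) |> Enum.sum()
end

identity = fn x -> x end

ByolFlowSketch.symmetric_loss(identity, identity, sq_dist, [1.0, 0.0], [0.0, 1.0])
# => 4.0
```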

Usage

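# Alias the module so it can be referenced as BYOL
alias Edifice.Contrastive.BYOL

# Build online and target networks
{online_model, target_model} = BYOL.build(
  encoder_dim: 287,
  projection_dim: 256,
  predictor_dim: 64
)

# After each training step, update the target via EMA
# (online_params and target_params are the two models' current parameters)
target_params = BYOL.ema_update(online_params, target_params, momentum: 0.996)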

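References

  • Grill et al., "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning", NeurIPS 2020.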

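Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build both the online and target BYOL networks.
  • build_online(opts \\ []) - Build the online network (encoder + projector + predictor).
  • build_target(opts \\ []) - Build the target network (encoder + projector, no predictor).
  • default_hidden_size() - Default encoder hidden dimension.
  • default_momentum() - Default EMA momentum.
  • default_predictor_dim() - Default predictor hidden dimension.
  • default_projection_dim() - Default projection dimension.
  • ema_update(online_params, target_params, opts \\ []) - Update target network parameters via exponential moving average.
  • loss(online_pred, target_proj) - Compute the BYOL loss (MSE between normalized online predictions and target projections).
  • output_size(opts \\ []) - Get the output size of the BYOL model.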

Types

build_opt()

@type build_opt() ::
  {:encoder_dim, pos_integer()}
  | {:projection_dim, pos_integer()}
  | {:predictor_dim, pos_integer()}
  | {:hidden_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build both the online and target BYOL networks.

The online network includes encoder + projector + predictor. The target network includes encoder + projector (no predictor).

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projector output dimension (default: 256)
  • :predictor_dim - Predictor hidden dimension (default: 64)
  • :hidden_size - Encoder hidden dimension (default: 256)

Returns

{online_model, target_model} tuple of Axon models.

build_online(opts \\ [])

@spec build_online(keyword()) :: Axon.t()

Build the online network (encoder + projector + predictor).

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projector output dimension (default: 256)
  • :predictor_dim - Predictor hidden dimension (default: 64)
  • :hidden_size - Encoder hidden dimension (default: 256)

Returns

An Axon model mapping inputs of shape [batch, encoder_dim] to predictor outputs of shape [batch, projection_dim].

build_target(opts \\ [])

@spec build_target(keyword()) :: Axon.t()

Build the target network (encoder + projector, no predictor).

Target network weights should be initialized as a copy of the online network (excluding the predictor) and updated via EMA.

Options

  • :encoder_dim - Input feature dimension (required)
  • :projection_dim - Projector output dimension (default: 256)
  • :hidden_size - Encoder hidden dimension (default: 256)

Returns

An Axon model mapping inputs of shape [batch, encoder_dim] to projector outputs of shape [batch, projection_dim].
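Initializing the target as a copy of the online network minus the predictor can be sketched as a map filter. This assumes, hypothetically, that parameters are a plain nested map keyed by layer name, that predictor layers carry a "predictor" name prefix, and that encoder and projector layers share names between the two models; the actual layer names used by Edifice are not documented here, and newer Axon versions may wrap parameters in a struct. TargetInitSketch is a hypothetical module name.

```elixir
defmodule TargetInitSketch do
  # Hypothetical naming convention: predictor layers start with this prefix,
  # while encoder/projector layers have identical names in both models.
  @predictor_prefix "predictor"

  @doc "Build target params from online params by dropping every predictor layer."
  def init_target(online_params) when is_map(online_params) do
    # Elixir data is immutable, so the filtered map is already an independent copy.
    online_params
    |> Enum.reject(fn {layer_name, _params} ->
      String.starts_with?(layer_name, @predictor_prefix)
    end)
    |> Map.new()
  end
end
```

After this, the target parameters would only change through EMA updates, never through gradients.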

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default encoder hidden dimension (256).

default_momentum()

@spec default_momentum() :: float()

Default EMA momentum (0.996).
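The BYOL paper does not keep the momentum fixed: it raises it from a base value of 0.996 toward 1 over training with the cosine schedule tau_k = 1 - (1 - tau_base) * (cos(pi * k / K) + 1) / 2, where k is the current step and K the total number of steps. Whether Edifice applies any schedule is not stated here; the sketch below (MomentumScheduleSketch is a hypothetical module) just computes a value that could be passed as the :momentum option of ema_update/3.

```elixir
defmodule MomentumScheduleSketch do
  @doc """
  Cosine momentum schedule from the BYOL paper (hypothetical helper):
  tau_k = 1 - (1 - base) * (cos(pi * k / total_steps) + 1) / 2.
  Returns `base` at step 0 and 1.0 at step `total_steps` (steps beyond are clamped).
  """
  def momentum(step, total_steps, base \\ 0.996)
      when is_integer(step) and is_integer(total_steps) and total_steps > 0 do
    progress = min(step, total_steps) / total_steps
    1.0 - (1.0 - base) * (:math.cos(:math.pi() * progress) + 1.0) / 2.0
  end
end
```

Halfway through training the schedule gives roughly 0.998, so the target moves more slowly as training progresses.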

default_predictor_dim()

@spec default_predictor_dim() :: pos_integer()

Default predictor hidden dimension (64).

default_projection_dim()

@spec default_projection_dim() :: pos_integer()

Default projection dimension (256).

ema_update(online_params, target_params, opts \\ [])

@spec ema_update(map(), map(), keyword()) :: map()

Update target network parameters via exponential moving average.

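$$\theta_{\text{target}} \leftarrow m \, \theta_{\text{target}} + (1 - m) \, \theta_{\text{online}}$$

where m is the :momentum option and the update is applied element-wise to each parameter tensor.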

Parameters

  • online_params - Current online network parameters (map of tensors)
  • target_params - Current target network parameters (map of tensors)

Options

  • :momentum - EMA momentum coefficient (default: 0.996)

Returns

Updated target parameters.
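A minimal sketch of the update rule above in plain Elixir, using lists of floats as stand-ins for Nx tensors so the arithmetic is easy to check. It assumes parameters are nested maps (layer name to parameter name to values) and only walks keys present in the target, so online-only layers such as the predictor are ignored; this is an assumption of the sketch, not a description of how Edifice's ema_update/3 is implemented. EmaSketch is a hypothetical module name. With m = 0.996 the target roughly averages over the last 1 / (1 - 0.996) = 250 online snapshots.

```elixir
defmodule EmaSketch do
  @doc """
  Hypothetical EMA update: target <- m * target + (1 - m) * online, element-wise.
  Leaves are lists of floats standing in for tensors. Keys are taken from the
  target map, so layers that exist only in the online params are skipped.
  """
  def update(online, target, momentum \\ 0.996) when is_map(target) do
    Map.new(target, fn {key, target_val} ->
      {key, blend(Map.fetch!(online, key), target_val, momentum)}
    end)
  end

  # Nested map (e.g. a layer's parameters): recurse
  defp blend(online, target, m) when is_map(target), do: update(online, target, m)

  # Leaf "tensor": blend element by element
  defp blend(online, target, m) when is_list(target) do
    Enum.zip_with(online, target, fn o, t -> m * t + (1 - m) * o end)
  end
end
```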

loss(online_pred, target_proj)

@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Compute the BYOL loss (MSE between normalized online predictions and target projections).

Parameters

  • online_pred - Online predictor output: [batch, projection_dim]
  • target_proj - Target projector output: [batch, projection_dim]

Returns

Scalar loss tensor.
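In the paper, the per-sample loss is the squared distance between L2-normalized vectors, which simplifies to 2 - 2 * cos_sim(p, z), so it ranges from 0 (same direction) to 4 (opposite directions). The sketch below follows that formulation on plain lists and averages over the batch; the exact reduction used by Edifice's loss/2 (mean vs. sum, per-dimension averaging) is not stated here and may differ. ByolLossSketch is a hypothetical module name, and the small epsilon guarding against zero vectors is an assumption of the sketch.

```elixir
defmodule ByolLossSketch do
  @eps 1.0e-12

  @doc """
  Hypothetical BYOL regression loss, averaged over the batch.
  preds and targets are lists of rows ([batch, dim]); per row:
  ||p / |p| - z / |z|||^2 = 2 - 2 * cos_sim(p, z).
  """
  def loss(preds, targets) do
    per_row =
      Enum.zip_with(preds, targets, fn p, z ->
        # Guard against division by zero for all-zero vectors
        2.0 - 2.0 * dot(p, z) / max(norm(p) * norm(z), @eps)
      end)

    Enum.sum(per_row) / length(per_row)
  end

  defp dot(a, b), do: Enum.zip_with(a, b, &(&1 * &2)) |> Enum.sum()
  defp norm(a), do: :math.sqrt(dot(a, a))
end
```

Because both vectors are normalized, scaling either one leaves the loss unchanged.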

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the BYOL model.