BYOL - Bootstrap Your Own Latent.
Implements BYOL from "Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning" (Grill et al., NeurIPS 2020). BYOL learns representations without negative pairs by using two networks: an online network that is trained, and a target network that is an exponential moving average (EMA) of the online network.
Key Innovations
- No negative pairs needed: Avoids representation collapse through an asymmetric design
- Online/target architecture: Target network provides stable regression targets
- Predictor head: The online network has an extra predictor that the target lacks
- EMA update: Target parameters are a slow-moving average of online parameters
Architecture
  Augmented View 1         Augmented View 2
        |                        |
        v                        v
  +============+           +============+
  |   Online   |           |   Target   |
  |  Encoder   |           |  Encoder   |  (EMA of online)
  +============+           +============+
        |                        |
        v                        v
  +============+           +============+
  |   Online   |           |   Target   |
  | Projector  |           | Projector  |  (EMA of online)
  +============+           +============+
        |                        |
        v                        |
  +============+                 |
  | Predictor  |                 |
  |  (online   |                 |
  |   only)    |                 |
  +============+                 |
        |                        |
        v                        v
       p_i       MSE Loss       z_j
        |                        |
        +----------->.<----------+
Usage
# Build online and target networks
{online_model, target_model} = BYOL.build(
  encoder_dim: 287,
  projection_dim: 256,
  predictor_dim: 64
)
# After each training step, update target via EMA
target_params = BYOL.ema_update(online_params, target_params, momentum: 0.996)
Summary
Functions
Build both the online and target BYOL networks.
Build the online network (encoder + projector + predictor).
Build the target network (encoder + projector, no predictor).
Default encoder hidden dimension
Default EMA momentum
Default predictor hidden dimension
Default projection dimension
Update target network parameters via exponential moving average.
Compute the BYOL loss (MSE between normalized online predictions and target projections).
Get the output size of the BYOL model.
Types
@type build_opt() :: {:encoder_dim, pos_integer()} | {:projection_dim, pos_integer()} | {:predictor_dim, pos_integer()} | {:hidden_size, pos_integer()}
Options for build/1.
Functions
Build both the online and target BYOL networks.
The online network includes encoder + projector + predictor. The target network includes encoder + projector (no predictor).
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projector output dimension (default: 256)
- :predictor_dim - Predictor hidden dimension (default: 64)
- :hidden_size - Encoder hidden dimension (default: 256)
Returns
{online_model, target_model} tuple of Axon models.
Build the online network (encoder + projector + predictor).
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projector output dimension (default: 256)
- :predictor_dim - Predictor hidden dimension (default: 64)
- :hidden_size - Encoder hidden dimension (default: 256)
Returns
An Axon model mapping inputs to predictor output.
Build the target network (encoder + projector, no predictor).
Target network weights should be initialized as a copy of the online network (excluding the predictor) and updated via EMA.
Options
- :encoder_dim - Input feature dimension (required)
- :projection_dim - Projector output dimension (default: 256)
- :hidden_size - Encoder hidden dimension (default: 256)
Returns
An Axon model mapping inputs to projection output.
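As a rough illustration of the initialization described above, the sketch below copies every online parameter group except the predictor. The layer names ("encoder", "projector", "predictor") and the top-level map layout are assumptions made for this example; the real parameter names depend on how the models are built, and plain lists stand in for Nx tensors.

```elixir
# Hypothetical parameter layout; real layer names and nesting may differ.
# Plain lists stand in for Nx tensors to keep the sketch dependency-free.
online_params = %{
  "encoder" => %{"kernel" => [0.1, 0.2]},
  "projector" => %{"kernel" => [0.3]},
  "predictor" => %{"kernel" => [0.4]}
}

# The target network has no predictor, so drop that group and keep the rest.
# Elixir data is immutable, so the remaining entries act as an independent copy.
target_params = Map.drop(online_params, ["predictor"])

IO.inspect(Enum.sort(Map.keys(target_params)))
```

Because the values are immutable, later updates to the online parameters produce new maps and leave this initial target copy unchanged.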
@spec default_momentum() :: float()
Default EMA momentum
@spec default_predictor_dim() :: pos_integer()
Default predictor hidden dimension
@spec default_projection_dim() :: pos_integer()
Default projection dimension
Update target network parameters via exponential moving average.
target_params = momentum * target_params + (1 - momentum) * online_params
Parameters
- online_params - Current online network parameters (map of tensors)
- target_params - Current target network parameters (map of tensors)
Options
:momentum- EMA momentum coefficient (default: 0.996)
Returns
Updated target parameters.
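The update rule can be sketched with Nx as follows. This is a minimal illustration, not the library's own implementation: `EmaSketch` is a hypothetical module, and it assumes both arguments are flat maps with identical keys, whereas real model parameters are usually nested per layer.

```elixir
# Run as a script (e.g. `elixir ema_sketch.exs`); Mix.install fetches Nx.
Mix.install([{:nx, "~> 0.7"}])

defmodule EmaSketch do
  # Hypothetical helper mirroring the documented call shape.
  # Assumes flat maps of tensors with the same keys in both arguments.
  def ema_update(online_params, target_params, opts \\ []) do
    momentum = Keyword.get(opts, :momentum, 0.996)

    Map.new(target_params, fn {name, target} ->
      online = Map.fetch!(online_params, name)

      # momentum * target + (1 - momentum) * online
      updated =
        Nx.add(
          Nx.multiply(target, momentum),
          Nx.multiply(online, 1.0 - momentum)
        )

      {name, updated}
    end)
  end
end

online = %{"w" => Nx.tensor([1.0, 2.0])}
target = %{"w" => Nx.tensor([0.0, 0.0])}

# With momentum 0.9: 0.9 * [0, 0] + 0.1 * [1, 2] is approximately [0.1, 0.2]
new_target = EmaSketch.ema_update(online, target, momentum: 0.9)
IO.inspect(new_target["w"])
```

With the default momentum of 0.996, each step moves the target only 0.4% of the way toward the online parameters, which is what keeps the regression targets stable.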
@spec loss(Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute the BYOL loss (MSE between normalized online predictions and target projections).
Parameters
- online_pred - Online predictor output: [batch, projection_dim]
- target_proj - Target projector output: [batch, projection_dim]
Returns
Scalar loss tensor.
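A minimal sketch of such a loss in Nx is shown below. `ByolLossSketch` is a hypothetical module, and the reduction is an assumption: squared differences are summed over the feature axis and averaged over the batch, following the form in the BYOL paper. The library's own `loss/2` may reduce differently (for example, averaging over all elements).

```elixir
# Run as a script (e.g. `elixir byol_loss_sketch.exs`); Mix.install fetches Nx.
Mix.install([{:nx, "~> 0.7"}])

defmodule ByolLossSketch do
  # L2-normalize each row; eps guards against division by zero.
  def l2_normalize(x, eps \\ 1.0e-8) do
    norm = x |> Nx.pow(2) |> Nx.sum(axes: [1], keep_axes: true) |> Nx.sqrt()
    Nx.divide(x, Nx.add(norm, eps))
  end

  # Assumed reduction: sum over features, mean over the batch.
  def loss(online_pred, target_proj) do
    p = l2_normalize(online_pred)
    z = l2_normalize(target_proj)

    Nx.subtract(p, z)
    |> Nx.pow(2)
    |> Nx.sum(axes: [1])
    |> Nx.mean()
  end
end

same = Nx.tensor([[3.0, 4.0], [1.0, 0.0]])
opposite = Nx.negate(same)

# Matching directions give a loss near 0; opposite directions give a loss near 4.
IO.inspect(ByolLossSketch.loss(same, same))
IO.inspect(ByolLossSketch.loss(same, opposite))
```

Under this reduction the loss for unit vectors equals 2 - 2 * cosine similarity, so it depends only on the direction of the two vectors and ranges from 0 to 4.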
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the BYOL model.