Edifice.Contrastive.MAE (Edifice v0.2.0)


MAE - Masked Autoencoder.

Implements the Masked Autoencoder from "Masked Autoencoders Are Scalable Vision Learners" (He et al., CVPR 2022), adapted for 1D sequence data. MAE masks a large portion of input patches (tokens) and trains an autoencoder to reconstruct the missing patches.

Key Innovations

  • High masking ratio: Masking 75% of patches creates a challenging pretext task that forces learning of strong representations
  • Asymmetric encoder-decoder: the encoder processes only the unmasked patches (efficient); the lightweight decoder processes all patches
  • Mask tokens: Learned placeholder tokens are added for masked positions before the decoder

Architecture

Input Patches [batch, num_patches, input_dim]
      |
      v
[Random Masking] (keep 25%)
      |
      v
+------------------+
|     Encoder      |  (processes only unmasked patches)
| (deeper, wider)  |
+------------------+
      |
      v
[Add Mask Tokens]  (insert learnable tokens at masked positions)
      |
      v
+------------------+
|     Decoder      |  (processes all patches)
| (shallow, narrow)|
+------------------+
      |
      v
[Reconstruction]   MSE loss on masked patches only
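
The flow above can be sketched with this module's public functions. The gather via Nx.take/3 and the toy shapes are illustrative assumptions about the external glue code, not the library's internals.

```elixir
alias Edifice.Contrastive.MAE

num_patches = 16

# Encoder/decoder pair (see build/1 for the full option list).
{encoder, decoder} = MAE.build(input_dim: 64, num_patches: num_patches)

# 75% of patch positions are masked; the rest stay visible.
{visible_idx, masked_idx} = MAE.generate_mask(num_patches, 0.75)

# Toy batch of patches: [batch, num_patches, input_dim].
patches = Nx.iota({2, num_patches, 64}, type: :f32)

# The encoder sees only the visible patches. Masking is applied
# externally; gathering with Nx.take/3 here is an assumption.
visible_patches = Nx.take(patches, visible_idx, axis: 1)

# visible_patches then flows: encoder -> insert mask tokens -> decoder,
# with reconstruction_loss/3 applied at masked_idx positions.
```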

Usage

# Build encoder and decoder
{encoder, decoder} = MAE.build(
  input_dim: 64,
  embed_dim: 256,
  decoder_dim: 128,
  mask_ratio: 0.75,
  num_encoder_layers: 4,
  num_decoder_layers: 2
)

# For pretraining: use both encoder + decoder with masking
# For downstream: use encoder only (discard decoder)
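
For the downstream case, a hedged sketch of keeping only the encoder and attaching a task head. The pooling step and the 10-class head are hypothetical; they use standard Axon combinators, not part of this module.

```elixir
alias Edifice.Contrastive.MAE

# Keep the pretrained encoder; discard the decoder.
{encoder, _decoder} = MAE.build(input_dim: 64, embed_dim: 256)

classifier =
  encoder
  # Mean-pool the [batch, num_patches, embed_dim] output over patches.
  |> Axon.nx(&Nx.mean(&1, axes: [1]))
  # Hypothetical 10-class head for a downstream task.
  |> Axon.dense(10, activation: :softmax)
```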

References

He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.

Summary

Types

  • build_opt() - Options for build/1.

Functions

  • build(opts \\ []) - Build the MAE encoder and decoder.
  • build_decoder(opts \\ []) - Build the MAE decoder.
  • build_encoder(opts \\ []) - Build the MAE encoder.
  • default_decoder_dim() - Default decoder dimension.
  • default_embed_dim() - Default encoder embedding dimension.
  • default_expand_factor() - Default feedforward expansion factor.
  • default_mask_ratio() - Default masking ratio.
  • default_num_decoder_layers() - Default number of decoder layers.
  • default_num_encoder_layers() - Default number of encoder layers.
  • generate_mask(num_patches, mask_ratio \\ default_mask_ratio()) - Generate a random mask for input patches.
  • output_size(opts \\ []) - Get the output size of the MAE encoder.
  • reconstruction_loss(reconstructed, original, masked_indices) - Compute reconstruction loss on masked patches only.

Types

build_opt()

@type build_opt() ::
  {:input_dim, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:num_encoder_layers, pos_integer()}
  | {:num_decoder_layers, pos_integer()}
  | {:decoder_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:num_patches, pos_integer() | nil}
  | {:mask_ratio, float()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: {Axon.t(), Axon.t()}

Build the MAE encoder and decoder.

Options

  • :input_dim - Dimension of each input patch/token (required)
  • :embed_dim - Encoder embedding dimension (default: 256)
  • :decoder_dim - Decoder hidden dimension (default: 128)
  • :mask_ratio - Fraction of patches to mask (default: 0.75)
  • :num_encoder_layers - Number of encoder layers (default: 4)
  • :num_decoder_layers - Number of decoder layers (default: 2)
  • :num_patches - Number of input patches/tokens (default: nil for dynamic)

Returns

{encoder, decoder} tuple of Axon models.

build_decoder(opts \\ [])

@spec build_decoder(keyword()) :: Axon.t()

Build the MAE decoder.

The decoder takes encoder output (with mask tokens inserted at masked positions) and reconstructs all patches.

Options

  • :input_dim - Original patch dimension for reconstruction target (required)
  • :embed_dim - Encoder output dimension, and hence the decoder input dimension (default: 256)
  • :decoder_dim - Decoder hidden dimension (default: 128)
  • :num_decoder_layers - Number of decoder layers (default: 2)
  • :num_patches - Total number of patches (default: nil)

Returns

An Axon model: [batch, num_patches, embed_dim] -> [batch, num_patches, input_dim]

build_encoder(opts \\ [])

@spec build_encoder(keyword()) :: Axon.t()

Build the MAE encoder.

The encoder processes only unmasked patches. In the full MAE pipeline, masking is applied externally before feeding to the encoder.

Options

  • :input_dim - Dimension of each input patch (required)
  • :embed_dim - Encoder embedding dimension (default: 256)
  • :num_encoder_layers - Number of encoder layers (default: 4)
  • :num_patches - Sequence length (default: nil)

Returns

An Axon model: [batch, num_visible, input_dim] -> [batch, num_visible, embed_dim]
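
A shape check, sketched with standard Axon build/predict functions. The batch size, the bare-tensor input (rather than a map keyed by input name), and the 4-of-16 visible count are assumptions.

```elixir
alias Edifice.Contrastive.MAE

# Visible-patches-only input: 4 of 16 patches survive a 0.75 mask.
encoder = MAE.build_encoder(input_dim: 64, embed_dim: 256)

{init_fn, predict_fn} = Axon.build(encoder)

input = Nx.iota({2, 4, 64}, type: :f32)
params = init_fn.(Nx.to_template(input), %{})
output = predict_fn.(params, input)

# Per the Returns note above, output should have shape {2, 4, 256}.
```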

default_decoder_dim()

@spec default_decoder_dim() :: pos_integer()

Default decoder dimension

default_embed_dim()

@spec default_embed_dim() :: pos_integer()

Default encoder embedding dimension

default_expand_factor()

@spec default_expand_factor() :: pos_integer()

Default feedforward expansion factor

default_mask_ratio()

@spec default_mask_ratio() :: float()

Default masking ratio

default_num_decoder_layers()

@spec default_num_decoder_layers() :: pos_integer()

Default number of decoder layers

default_num_encoder_layers()

@spec default_num_encoder_layers() :: pos_integer()

Default number of encoder layers

generate_mask(num_patches, mask_ratio \\ default_mask_ratio())

@spec generate_mask(non_neg_integer(), float()) :: {Nx.Tensor.t(), Nx.Tensor.t()}

Generate a random mask for input patches.

Returns a tuple of {visible_indices, masked_indices} for a given number of patches and mask ratio.

Parameters

  • num_patches - Total number of patches
  • mask_ratio - Fraction to mask (default: 0.75)

Returns

{visible_indices, masked_indices} tensors.
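
As a concrete check of the split sizes, assuming the number of masked patches is num_patches * mask_ratio rounded to an integer (the exact rounding rule is an assumption):

```elixir
alias Edifice.Contrastive.MAE

{visible_idx, masked_idx} = MAE.generate_mask(16, 0.75)

# With 16 patches at a 0.75 ratio, expect 12 masked and 4 visible
# indices (assuming round-to-nearest).
Nx.size(masked_idx)
Nx.size(visible_idx)
```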

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of the MAE encoder.

reconstruction_loss(reconstructed, original, masked_indices)

@spec reconstruction_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  Nx.Tensor.t()

Compute reconstruction loss on masked patches only.

Parameters

  • reconstructed - Decoder output: [batch, num_patches, input_dim]
  • original - Original input: [batch, num_patches, input_dim]
  • masked_indices - Indices of masked patches: [num_masked]

Returns

Scalar MSE loss over masked patches.
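
A minimal sketch of what "MSE over masked patches only" amounts to, written directly in Nx. This mirrors the documented semantics; it is not claimed to be the module's internal implementation.

```elixir
# Toy tensors: [batch = 1, num_patches = 4, input_dim = 2].
original = Nx.iota({1, 4, 2}, type: :f32)
reconstructed = Nx.add(original, 0.5)
masked_indices = Nx.tensor([1, 3])

# Gather only the masked positions along the patch axis.
masked_rec = Nx.take(reconstructed, masked_indices, axis: 1)
masked_orig = Nx.take(original, masked_indices, axis: 1)

# MSE over masked positions only: every element differs by 0.5,
# so the mean squared error is 0.25.
loss = masked_rec |> Nx.subtract(masked_orig) |> Nx.pow(2) |> Nx.mean()
```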