MAE - Masked Autoencoder.
Implements the Masked Autoencoder from "Masked Autoencoders Are Scalable Vision Learners" (He et al., CVPR 2022), adapted for 1D sequence data. MAE masks a large portion of input patches (tokens) and trains an autoencoder to reconstruct the missing patches.
Key Innovations
- High masking ratio: Masking 75% of patches creates a challenging pretext task that forces learning of strong representations
- Asymmetric encoder-decoder: Encoder only processes unmasked patches (efficient), decoder is lightweight and processes all patches
- Mask tokens: Learned placeholder tokens are added for masked positions before the decoder
Architecture
Input Patches [batch, num_patches, input_dim]
|
v
[Random Masking] (keep 25%)
|
v
+------------------+
| Encoder | (processes only unmasked patches)
| (deeper, wider) |
+------------------+
|
v
[Add Mask Tokens] (insert learnable tokens at masked positions)
|
v
+------------------+
| Decoder | (processes all patches)
| (shallow, narrow)|
+------------------+
|
v
[Reconstruction] MSE loss on masked patches only

Usage
# Build encoder and decoder
{encoder, decoder} = MAE.build(
  input_dim: 64,
  embed_dim: 256,
  decoder_dim: 128,
  mask_ratio: 0.75,
  num_encoder_layers: 4,
  num_decoder_layers: 2
)
# For pretraining: use both encoder + decoder with masking
# For downstream: use encoder only (discard decoder)

References

He, K., Chen, X., Xie, S., Li, Y., Dollár, P., & Girshick, R. (2022). Masked Autoencoders Are Scalable Vision Learners. CVPR 2022.
Summary
Functions
Build the MAE encoder and decoder.
Build the MAE decoder.
Build the MAE encoder.
Default decoder dimension
Default encoder embedding dimension
Default feedforward expansion factor
Default masking ratio
Default number of decoder layers
Default number of encoder layers
Generate a random mask for input patches.
Get the output size of the MAE encoder.
Compute reconstruction loss on masked patches only.
Types
@type build_opt() ::
        {:input_dim, pos_integer()}
        | {:embed_dim, pos_integer()}
        | {:num_encoder_layers, pos_integer()}
        | {:num_decoder_layers, pos_integer()}
        | {:decoder_dim, pos_integer()}
        | {:expand_factor, pos_integer()}
        | {:num_patches, pos_integer() | nil}
        | {:mask_ratio, float()}
Options for build/1.
Functions
Build the MAE encoder and decoder.
Options
:input_dim - Dimension of each input patch/token (required)
:embed_dim - Encoder embedding dimension (default: 256)
:decoder_dim - Decoder hidden dimension (default: 128)
:mask_ratio - Fraction of patches to mask (default: 0.75)
:num_encoder_layers - Number of encoder layers (default: 4)
:num_decoder_layers - Number of decoder layers (default: 2)
:num_patches - Number of input patches/tokens (default: nil for dynamic)
Returns
{encoder, decoder} tuple of Axon models.
Build the MAE decoder.
The decoder takes encoder output (with mask tokens inserted at masked positions) and reconstructs all patches.
Options
:input_dim - Original patch dimension for reconstruction target (required)
:embed_dim - Encoder output dimension / decoder input dimension (default: 256)
:decoder_dim - Decoder hidden dimension (default: 128)
:num_decoder_layers - Number of decoder layers (default: 2)
:num_patches - Total number of patches (default: nil)
Returns
An Axon model: [batch, num_patches, embed_dim] -> [batch, num_patches, input_dim]
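The mask-token insertion step between encoder and decoder can be sketched as follows. This is illustrative only: the variable names (encoded, dec_params), the shapes, and the zero placeholder standing in for the learned mask token are assumptions, not the module's implementation.

```elixir
# Assume 16 patches, mask ratio 0.75, batch size 2, embed_dim 256.
{visible_idx, masked_idx} = MAE.generate_mask(16, 0.75)

# encoded: [2, 4, 256] encoder output over the visible patches.
# The real mask token is a learned parameter; zeros are a stand-in here.
mask_tokens = Nx.broadcast(0.0, {2, Nx.size(masked_idx), 256})
full = Nx.concatenate([encoded, mask_tokens], axis: 1)

# `order` holds the original position of each row in `full`; taking at
# argsort(order) applies the inverse permutation, restoring patch order.
order = Nx.concatenate([visible_idx, masked_idx])
full = Nx.take(full, Nx.argsort(order), axis: 1)

reconstructed = Axon.predict(decoder, dec_params, full)  # [2, 16, input_dim]
```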
Build the MAE encoder.
The encoder processes only unmasked patches. In the full MAE pipeline, masking is applied externally before feeding to the encoder.
Options
:input_dim - Dimension of each input patch (required)
:embed_dim - Encoder embedding dimension (default: 256)
:num_encoder_layers - Number of encoder layers (default: 4)
:num_patches - Sequence length (default: nil)
Returns
An Axon model: [batch, num_visible, input_dim] -> [batch, num_visible, embed_dim]
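The external masking step can be sketched with Nx.take along the patch axis. A minimal sketch, assuming 64 patches and hypothetical `patches` and `params` values:

```elixir
# Select the 25% visible patches before feeding the encoder.
{visible_idx, _masked_idx} = MAE.generate_mask(64, 0.75)
visible = Nx.take(patches, visible_idx, axis: 1)   # [batch, 16, input_dim]
encoded = Axon.predict(encoder, params, visible)   # [batch, 16, embed_dim]
```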
@spec default_decoder_dim() :: pos_integer()
Default decoder dimension
@spec default_embed_dim() :: pos_integer()
Default encoder embedding dimension
@spec default_expand_factor() :: pos_integer()
Default feedforward expansion factor
@spec default_mask_ratio() :: float()
Default masking ratio
@spec default_num_decoder_layers() :: pos_integer()
Default number of decoder layers
@spec default_num_encoder_layers() :: pos_integer()
Default number of encoder layers
@spec generate_mask(non_neg_integer(), float()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
Generate a random mask for input patches.
Returns a tuple of {visible_indices, masked_indices} for a given
number of patches and mask ratio.
Parameters
num_patches - Total number of patches
mask_ratio - Fraction to mask (default: 0.75)
Returns
{visible_indices, masked_indices} tensors.
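A usage sketch, assuming 16 patches at the default 0.75 mask ratio (the `patches` tensor is illustrative):

```elixir
{visible_idx, masked_idx} = MAE.generate_mask(16, 0.75)
# visible_idx holds the 25% of positions kept (4), masked_idx the other 12.
patches = Nx.iota({2, 16, 64}, type: :f32)
visible = Nx.take(patches, visible_idx, axis: 1)   # shape {2, 4, 64}
```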
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of the MAE encoder.
@spec reconstruction_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()
Compute reconstruction loss on masked patches only.
Parameters
reconstructed - Decoder output: [batch, num_patches, input_dim]
original - Original input: [batch, num_patches, input_dim]
masked_indices - Indices of masked patches: [num_masked]
Returns
Scalar MSE loss over masked patches.
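For intuition, an equivalent computation in plain Nx (not the module's actual implementation): gather the masked positions from both tensors, then take the mean squared difference.

```elixir
# Select only the masked patches along the patch axis, then MSE.
masked_rec  = Nx.take(reconstructed, masked_indices, axis: 1)
masked_orig = Nx.take(original, masked_indices, axis: 1)

loss =
  masked_rec
  |> Nx.subtract(masked_orig)
  |> Nx.pow(2)
  |> Nx.mean()
```

Restricting the loss to masked patches keeps the model from wasting capacity copying visible inputs through the decoder.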