# `Edifice.Contrastive.MAE`
[🔗](https://github.com/blasphemetheus/edifice/blob/main/lib/edifice/contrastive/mae.ex#L1)

MAE - Masked Autoencoder.

Implements the Masked Autoencoder from "Masked Autoencoders Are Scalable
Vision Learners" (He et al., CVPR 2022), adapted for 1D sequence data.
MAE masks a large portion of input patches (tokens) and trains an
autoencoder to reconstruct the missing patches.

## Key Innovations

- **High masking ratio**: Masking 75% of patches creates a challenging
  pretext task that forces the model to learn strong representations
- **Asymmetric encoder-decoder**: The encoder processes only the unmasked
  patches (efficient), while a lightweight decoder processes all patches
- **Mask tokens**: Learned placeholder tokens are added for masked positions
  before the decoder

## Architecture

```
Input Patches [batch, num_patches, input_dim]
      |
      v
[Random Masking] (keep 25%)
      |
      v
+------------------+
|     Encoder      |  (processes only unmasked patches)
| (deeper, wider)  |
+------------------+
      |
      v
[Add Mask Tokens]  (insert learnable tokens at masked positions)
      |
      v
+------------------+
|     Decoder      |  (processes all patches)
| (shallow, narrow)|
+------------------+
      |
      v
[Reconstruction]   MSE loss on masked patches only
```

## Usage

    # Build encoder and decoder
    {encoder, decoder} = MAE.build(
      input_dim: 64,
      embed_dim: 256,
      decoder_dim: 128,
      mask_ratio: 0.75,
      num_encoder_layers: 4,
      num_decoder_layers: 2
    )

    # For pretraining: use both encoder + decoder with masking
    # For downstream: use encoder only (discard decoder)
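
A rough sketch of one pretraining step that ties these pieces together
(`patches`, `encoder_params`, `decoder_params`, and `full_sequence` are
hypothetical placeholders here, and the mask-token insertion between
encoder and decoder is elided):

    # Mask 75% of 16 patches -> 4 visible, 12 masked
    {visible_idx, masked_idx} = MAE.generate_mask(16, 0.75)

    # The encoder sees only the visible patches
    visible = Nx.take(patches, visible_idx, axis: 1)
    encoded = Axon.predict(encoder, encoder_params, visible)

    # ...insert learned mask tokens at `masked_idx` to form `full_sequence`...
    reconstructed = Axon.predict(decoder, decoder_params, full_sequence)

    # Score the reconstruction on masked patches only
    loss = MAE.reconstruction_loss(reconstructed, patches, masked_idx)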

## References
- Paper: https://arxiv.org/abs/2111.06377

# `build_opt`

```elixir
@type build_opt() ::
  {:input_dim, pos_integer()}
  | {:embed_dim, pos_integer()}
  | {:num_encoder_layers, pos_integer()}
  | {:num_decoder_layers, pos_integer()}
  | {:decoder_dim, pos_integer()}
  | {:expand_factor, pos_integer()}
  | {:num_patches, pos_integer() | nil}
  | {:mask_ratio, float()}
```

Options for `build/1`.

# `build`

```elixir
@spec build([build_opt()]) :: {Axon.t(), Axon.t()}
```

Build the MAE encoder and decoder.

## Options
  - `:input_dim` - Dimension of each input patch/token (required)
  - `:embed_dim` - Encoder embedding dimension (default: 256)
  - `:decoder_dim` - Decoder hidden dimension (default: 128)
  - `:mask_ratio` - Fraction of patches to mask (default: 0.75)
  - `:num_encoder_layers` - Number of encoder layers (default: 4)
  - `:num_decoder_layers` - Number of decoder layers (default: 2)
  - `:expand_factor` - Feedforward expansion factor (default: `default_expand_factor/0`)
  - `:num_patches` - Number of input patches/tokens (default: `nil` for a dynamic sequence length)

## Returns
  `{encoder, decoder}` tuple of Axon models.

# `build_decoder`

```elixir
@spec build_decoder(keyword()) :: Axon.t()
```

Build the MAE decoder.

The decoder takes encoder output (with mask tokens inserted at masked
positions) and reconstructs all patches.

## Options
  - `:input_dim` - Original patch dimension for reconstruction target (required)
  - `:embed_dim` - Encoder output dimension / decoder input (default: 256)
  - `:decoder_dim` - Decoder hidden dimension (default: 128)
  - `:num_decoder_layers` - Number of decoder layers (default: 2)
  - `:num_patches` - Total number of patches (default: `nil`)

## Returns
  An Axon model: [batch, num_patches, embed_dim] -> [batch, num_patches, input_dim]
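
A standalone sketch (hedged: the `Nx.iota` input merely stands in for
encoder output with mask tokens already inserted at the masked positions):

    decoder = MAE.build_decoder(input_dim: 64, embed_dim: 256, decoder_dim: 128, num_patches: 16)

    {init_fn, predict_fn} = Axon.build(decoder)
    params = init_fn.(Nx.template({1, 16, 256}, :f32), %{})

    latent = Nx.iota({1, 16, 256}, type: :f32)
    reconstructed = predict_fn.(params, latent)  # [1, 16, 64]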

# `build_encoder`

```elixir
@spec build_encoder(keyword()) :: Axon.t()
```

Build the MAE encoder.

The encoder processes only unmasked patches. In the full MAE pipeline,
masking is applied externally before feeding to the encoder.

## Options
  - `:input_dim` - Dimension of each input patch (required)
  - `:embed_dim` - Encoder embedding dimension (default: 256)
  - `:num_encoder_layers` - Number of encoder layers (default: 4)
  - `:num_patches` - Sequence length (default: `nil`)

## Returns
  An Axon model: [batch, num_visible, input_dim] -> [batch, num_visible, embed_dim]
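
A sketch of downstream feature extraction, running the encoder over all
patches with no masking (the `Nx.iota` values are arbitrary stand-in data):

    encoder = MAE.build_encoder(input_dim: 64, embed_dim: 256, num_encoder_layers: 4)

    {init_fn, predict_fn} = Axon.build(encoder)
    params = init_fn.(Nx.template({1, 16, 64}, :f32), %{})

    features = predict_fn.(params, Nx.iota({1, 16, 64}, type: :f32))  # [1, 16, 256]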

# `default_decoder_dim`

```elixir
@spec default_decoder_dim() :: pos_integer()
```

Default decoder dimension.

# `default_embed_dim`

```elixir
@spec default_embed_dim() :: pos_integer()
```

Default encoder embedding dimension.

# `default_expand_factor`

```elixir
@spec default_expand_factor() :: pos_integer()
```

Default feedforward expansion factor.

# `default_mask_ratio`

```elixir
@spec default_mask_ratio() :: float()
```

Default masking ratio.

# `default_num_decoder_layers`

```elixir
@spec default_num_decoder_layers() :: pos_integer()
```

Default number of decoder layers.

# `default_num_encoder_layers`

```elixir
@spec default_num_encoder_layers() :: pos_integer()
```

Default number of encoder layers.

# `generate_mask`

```elixir
@spec generate_mask(non_neg_integer(), float()) :: {Nx.Tensor.t(), Nx.Tensor.t()}
```

Generate a random mask for input patches.

Returns a tuple of `{visible_indices, masked_indices}` for a given
number of patches and mask ratio.

## Parameters
  - `num_patches` - Total number of patches
  - `mask_ratio` - Fraction to mask (default: 0.75)

## Returns
  `{visible_indices, masked_indices}` tensors.
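
For example (a sketch: the split is random, `patches` is a hypothetical
`[batch, 16, input_dim]` tensor, and the masked count is assumed to round
to `num_patches * mask_ratio`):

    {visible_idx, masked_idx} = MAE.generate_mask(16, 0.75)
    Nx.size(visible_idx)  #=> 4
    Nx.size(masked_idx)   #=> 12

    # Gather the visible patches along the patch axis
    visible = Nx.take(patches, visible_idx, axis: 1)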

# `output_size`

```elixir
@spec output_size(keyword()) :: non_neg_integer()
```

Get the output size of the MAE encoder.
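
Presumably this mirrors the `:embed_dim` option; a hedged sketch:

    MAE.output_size(embed_dim: 256)  #=> 256 (assumed: equals :embed_dim)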

# `reconstruction_loss`

```elixir
@spec reconstruction_loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) ::
  Nx.Tensor.t()
```

Compute reconstruction loss on masked patches only.

## Parameters
  - `reconstructed` - Decoder output: [batch, num_patches, input_dim]
  - `original` - Original input: [batch, num_patches, input_dim]
  - `masked_indices` - Indices of masked patches: [num_masked]

## Returns
  Scalar MSE loss over masked patches.
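
Conceptually this should match gathering the masked positions and taking a
plain MSE over them (a sketch, assuming mean reduction over all masked
elements):

    rec = Nx.take(reconstructed, masked_indices, axis: 1)
    orig = Nx.take(original, masked_indices, axis: 1)

    rec |> Nx.subtract(orig) |> Nx.pow(2) |> Nx.mean()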

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
