# `Edifice.Interpretability.SparseAutoencoder`

Sparse Autoencoder (SAE) for mechanistic interpretability.

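Learns a sparse, overcomplete dictionary of features (more features than input
dimensions, i.e. `dict_size` larger than `input_size`) from neural network
activations. It is used to decompose a model's internal activations into
interpretable feature directions.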

## Architecture

```
Input [batch, input_size]
      |
Encoder: dense(dict_size) + ReLU
      |
Sparsify: top-k mask (:top_k); none in :l1 mode (penalty applied in loss/4)
      |
[batch, dict_size]  (sparse activations)
      |
Decoder: dense(input_size)
      |
Output [batch, input_size]  (reconstruction)
```
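
To make the diagram concrete, here is a minimal sketch of one sample's forward pass in `:l1` mode (no hard sparsify step), written with plain Elixir lists instead of Nx tensors. The module `SAEForwardSketch`, the explicit weight arguments, and the `[in, out]` weight layout are hypothetical choices for illustration, not Edifice's implementation.

```elixir
defmodule SAEForwardSketch do
  # Hypothetical list-based sketch, not part of Edifice.
  # Weights are nested lists shaped [in, out]; biases are lists of length out.

  # Dense layer: y_j = b_j + sum_i x_i * w[i][j]
  def dense(x, w, b) do
    w
    |> Enum.zip(x)
    |> Enum.reduce(b, fn {row, xi}, acc ->
      Enum.zip_with(acc, row, fn a, wij -> a + xi * wij end)
    end)
  end

  # Element-wise ReLU.
  def relu(v), do: Enum.map(v, &max(&1, 0.0))

  # Encoder: dense(dict_size) + ReLU, giving the hidden activations.
  def encode(x, w_enc, b_enc), do: x |> dense(w_enc, b_enc) |> relu()

  # Decoder: dense(input_size), giving the reconstruction.
  def decode(h, w_dec, b_dec), do: dense(h, w_dec, b_dec)

  # Returns {reconstruction, hidden_activations} for a single sample.
  def forward(x, {w_enc, b_enc}, {w_dec, b_dec}) do
    h = encode(x, w_enc, b_enc)
    {decode(h, w_dec, b_dec), h}
  end
end
```

For a batch, the same function would be mapped over each row; a real implementation would use batched Nx operations instead.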

## Sparsity Modes

- `:top_k` - Zeroes out all but the `top_k` largest activations in each sample (hard sparsity)
- `:l1` - No hard sparsity in the forward pass; sparsity is encouraged during training by the L1 penalty in `loss/4`
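
The `:top_k` step can be sketched on a single sample's activation list as follows. `TopKSketch` is a hypothetical module, not Edifice code, and breaking ties by position (earlier index wins) is an assumption, since the tie rule Edifice uses is not documented here.

```elixir
defmodule TopKSketch do
  # Hypothetical sketch, not part of Edifice.
  # Keeps the k largest entries of one sample's activations and zeroes the rest.
  # Ties are broken by position (earlier index wins); this rule is an assumption.
  def top_k(acts, k) when is_list(acts) and is_integer(k) and k >= 0 do
    keep =
      acts
      |> Enum.with_index()
      |> Enum.sort_by(fn {v, i} -> {-v, i} end)
      |> Enum.take(k)
      |> MapSet.new(fn {_v, i} -> i end)

    acts
    |> Enum.with_index()
    |> Enum.map(fn {v, i} -> if MapSet.member?(keep, i), do: v, else: 0.0 end)
  end

  # Applies top_k/2 independently to each row of a [batch, dict_size] list.
  def top_k_batch(batch, k), do: Enum.map(batch, &top_k(&1, k))
end
```

Because the mask is applied after ReLU, activations are non-negative; when fewer than `k` of them are positive, some of the kept slots are already zero, so a sample can end up with fewer than `k` active features.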

## Usage

    model = SparseAutoencoder.build(
      input_size: 256,
      dict_size: 4096,
      top_k: 32,
      sparsity: :top_k
    )

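    # `input` holds the original activations [batch, 256], `reconstruction` is the
    # output of `model`, and `hidden_acts` holds the sparse activations [batch, 4096]
    # (the output of an encoder built with `build_encoder/1`).
    # The loss combines reconstruction MSE with an L1 penalty on `hidden_acts`.
    loss = SparseAutoencoder.loss(input, reconstruction, hidden_acts, l1_coeff: 1.0e-3)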

## References

- Bricken et al., "Towards Monosemanticity" (Anthropic, 2023)
- Cunningham et al., "Sparse Autoencoders Find Highly Interpretable Features" (2023)

# `build_opt`

```elixir
@type build_opt() ::
  {:input_size, pos_integer()}
  | {:dict_size, pos_integer()}
  | {:top_k, pos_integer()}
  | {:sparsity, :top_k | :l1}
```

Options for `build/1` and `build_encoder/1`.

# `build`

```elixir
@spec build([build_opt()]) :: Axon.t()
```

Build a sparse autoencoder.

## Options

  - `:input_size` - Dimension of input activations (required)
  - `:dict_size` - Number of dictionary features, typically much larger than `input_size` (default: 4096)
  - `:top_k` - Number of features kept active per sample when `sparsity: :top_k` (default: 32)
  - `:sparsity` - Sparsity mode: `:top_k` or `:l1` (default: `:top_k`)

## Returns

  An Axon model mapping `[batch, input_size]` to `[batch, input_size]`.

# `build_encoder`

```elixir
@spec build_encoder([build_opt()]) :: Axon.t()
```

Build only the encoder portion (dense + ReLU + sparsify), for extracting sparse feature activations.

The returned Axon model maps `[batch, input_size]` to the sparse hidden activations `[batch, dict_size]`.
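
With `sparsity: :top_k`, each row of the encoder output should contain at most `top_k` nonzero entries. A small hypothetical helper for checking this, assuming the activations have first been converted to nested lists (for example with `Nx.to_list/1`); `SAEL0Sketch` is an illustrative name, not part of Edifice:

```elixir
defmodule SAEL0Sketch do
  # Hypothetical helpers, not part of Edifice.
  # hidden_acts is a [batch, dict_size] nested list of encoder activations.

  # Number of nonzero (active) features in each sample.
  def active_counts(hidden_acts) do
    Enum.map(hidden_acts, fn row -> Enum.count(row, &(&1 != 0)) end)
  end

  # Mean number of active features per sample (often called L0).
  def mean_l0(hidden_acts) do
    counts = active_counts(hidden_acts)
    Enum.sum(counts) / length(counts)
  end
end
```

Checking that every count is at most `top_k` is a cheap sanity test on encoder output.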

# `loss`

```elixir
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
```

Compute the SAE training loss: reconstruction mean squared error (MSE) plus an L1 sparsity penalty on `hidden_acts`, scaled by `:l1_coeff`.

## Parameters

  - `input` - Original activations `[batch, input_size]`
  - `reconstruction` - SAE output `[batch, input_size]`
  - `hidden_acts` - Sparse hidden activations `[batch, dict_size]`
  - `opts` - Options:
    - `:l1_coeff` - L1 penalty coefficient (default: 1.0e-3)

## Returns

  Scalar loss tensor.
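
The loss can be sketched on nested lists as below. The reductions are assumptions, not confirmed against Edifice's `loss/4`: MSE is averaged over every element, while the L1 term sums `|h|` within each sample and then averages over the batch. `SAELossSketch` is a hypothetical module name.

```elixir
defmodule SAELossSketch do
  # Hypothetical sketch, not Edifice's loss/4. Inputs are nested lists:
  # input and reconstruction are [batch, input_size], hidden_acts is [batch, dict_size].
  # Assumed reductions: MSE averages over all elements; L1 sums |h| per sample,
  # then averages over the batch.
  def loss(input, reconstruction, hidden_acts, opts \\ []) do
    l1_coeff = Keyword.get(opts, :l1_coeff, 1.0e-3)
    mse(input, reconstruction) + l1_coeff * l1(hidden_acts)
  end

  # Mean squared reconstruction error over every element.
  def mse(input, reconstruction) do
    squared =
      input
      |> Enum.zip_with(reconstruction, fn xs, ys ->
        Enum.zip_with(xs, ys, fn x, y -> (x - y) * (x - y) end)
      end)
      |> List.flatten()

    Enum.sum(squared) / length(squared)
  end

  # Per-sample sum of absolute activations, averaged over the batch.
  def l1(hidden_acts) do
    per_sample = Enum.map(hidden_acts, fn row -> row |> Enum.map(&abs/1) |> Enum.sum() end)
    Enum.sum(per_sample) / length(per_sample)
  end
end
```

A larger `:l1_coeff` pushes more hidden activations toward zero at the cost of reconstruction quality.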

# `output_size`

```elixir
@spec output_size(keyword()) :: pos_integer()
```

Get the output size of the SAE (same as `:input_size`).

