Edifice.Interpretability.SparseAutoencoder (Edifice v0.2.0)

Sparse Autoencoder (SAE) for mechanistic interpretability.

Learns a sparse, overcomplete dictionary of features from neural network activations. Used to decompose model internals into interpretable directions.

Architecture

Input [batch, input_size]
      |
Encoder: dense(dict_size) + ReLU
      |
Sparsify: top-k or L1 penalty
      |
[batch, dict_size]  (sparse activations)
      |
Decoder: dense(input_size)
      |
Output [batch, input_size]  (reconstruction)

Sparsity Modes

  • :top_k — Zero out all but the top-k activations per sample (hard sparsity)
  • :l1 — No hard sparsity in the forward pass; sparsity is instead encouraged through the L1 penalty term in loss/4
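
The :top_k mode's semantics can be illustrated on a plain list (the actual model applies this to Nx tensors inside the Axon graph; `TopKSketch` below is a hypothetical helper, not part of Edifice):

```elixir
defmodule TopKSketch do
  # Keep the k largest activations in a sample; zero out the rest.
  # Plain-list sketch of hard top-k sparsity. Note: with ties at the
  # threshold, this simple version may keep more than k values.
  def top_k(acts, k) do
    threshold =
      acts
      |> Enum.sort(:desc)
      |> Enum.at(k - 1)

    Enum.map(acts, fn a -> if a >= threshold, do: a, else: 0.0 end)
  end
end

TopKSketch.top_k([0.1, 0.9, 0.3, 0.5], 2)
# => [0.0, 0.9, 0.0, 0.5]
```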

Usage

model = SparseAutoencoder.build(
  input_size: 256,
  dict_size: 4096,
  top_k: 32,
  sparsity: :top_k
)

# Training loss includes reconstruction + L1 penalty
loss = SparseAutoencoder.loss(input, reconstruction, hidden_acts, l1_coeff: 1.0e-3)

References

  • Bricken et al., "Towards Monosemanticity" (Anthropic, 2023)
  • Cunningham et al., "Sparse Autoencoders Find Highly Interpretable Features" (2023)

Summary

Types

Options for build/1.

Functions

Build a sparse autoencoder.

Build the encoder portion only (for extracting sparse activations).

Compute SAE training loss: reconstruction MSE + L1 sparsity penalty.

Get the output size of the SAE (same as input_size).

Types

build_opt()

@type build_opt() ::
  {:input_size, pos_integer()}
  | {:dict_size, pos_integer()}
  | {:top_k, pos_integer()}
  | {:sparsity, :top_k | :l1}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a sparse autoencoder.

Options

  • :input_size - Dimension of input activations (required)
  • :dict_size - Number of dictionary features, typically >> input_size (default: 4096)
  • :top_k - Number of active features when sparsity: :top_k (default: 32)
  • :sparsity - Sparsity mode: :top_k or :l1 (default: :top_k)

Returns

An Axon model mapping [batch, input_size] to [batch, input_size].

build_encoder(opts \\ [])

@spec build_encoder([build_opt()]) :: Axon.t()

Build the encoder portion only (for extracting sparse activations).

Returns the sparse hidden activations [batch, dict_size].

loss(input, reconstruction, hidden_acts, opts \\ [])

@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute SAE training loss: reconstruction MSE + L1 sparsity penalty.

Parameters

  • input - Original activations [batch, input_size]
  • reconstruction - SAE output [batch, input_size]
  • hidden_acts - Sparse hidden activations [batch, dict_size]
  • opts - Options:
    • :l1_coeff - L1 penalty coefficient (default: 1.0e-3)

Returns

Scalar loss tensor.
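
The loss combines a mean-squared reconstruction error with an L1 penalty on the hidden activations. A flat-list sketch of that formula (the real loss/4 operates on Nx tensors; `SaeLossSketch` is an illustrative stand-in, not the library implementation):

```elixir
defmodule SaeLossSketch do
  # loss = mean((input - reconstruction)^2) + l1_coeff * mean(|hidden_acts|)
  def loss(input, reconstruction, hidden_acts, l1_coeff \\ 1.0e-3) do
    mse =
      input
      |> Enum.zip(reconstruction)
      |> Enum.map(fn {x, x_hat} -> (x - x_hat) * (x - x_hat) end)
      |> mean()

    l1 = hidden_acts |> Enum.map(&abs/1) |> mean()

    mse + l1_coeff * l1
  end

  defp mean(xs), do: Enum.sum(xs) / length(xs)
end

SaeLossSketch.loss([1.0, 0.0], [0.5, 0.5], [0.2, 0.0, 0.8], 1.0e-3)
```

Raising :l1_coeff trades reconstruction fidelity for sparser hidden activations.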

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of the SAE (same as input_size).