Sparse Autoencoder (SAE) for mechanistic interpretability.
Learns a sparse, overcomplete dictionary of features from neural network activations. Used to decompose model internals into interpretable directions.
Architecture
Input [batch, input_size]
|
Encoder: dense(dict_size) + ReLU
|
Sparsify: top-k or L1 penalty
|
[batch, dict_size] (sparse activations)
|
Decoder: dense(input_size)
|
Output [batch, input_size] (reconstruction)
Sparsity Modes
- :top_k — Zero out all but the top-k activations per sample (hard sparsity)
- :l1 — No hard sparsity in the forward pass; use loss/4 with an L1 penalty
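The top-k mode can be illustrated numerically. The following is a minimal numpy sketch of the masking step (not this module's Nx implementation): each row keeps its k largest activations and zeroes the rest, while the :l1 mode would leave the row dense and rely on the loss to drive activations toward zero.

```python
import numpy as np

def top_k_sparsify(h, k):
    """Keep the k largest activations per row, zero the rest (hard sparsity)."""
    out = np.zeros_like(h)
    idx = np.argsort(h, axis=-1)[:, -k:]  # indices of the k largest entries per row
    np.put_along_axis(out, idx, np.take_along_axis(h, idx, axis=-1), axis=-1)
    return out

h = np.array([[0.1, 0.9, 0.0, 0.4, 0.7]])
print(top_k_sparsify(h, 2))  # -> [[0.  0.9 0.  0.  0.7]]
```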
Usage
model = SparseAutoencoder.build(
input_size: 256,
dict_size: 4096,
top_k: 32,
sparsity: :top_k
)
# Training loss includes reconstruction + L1 penalty
loss = SparseAutoencoder.loss(input, reconstruction, hidden_acts, l1_coeff: 1.0e-3)
References
- Bricken et al., "Towards Monosemanticity" (Anthropic, 2023)
- Cunningham et al., "Sparse Autoencoders Find Highly Interpretable Features" (2023)
Summary
Functions
Build a sparse autoencoder.
Build the encoder portion only (for extracting sparse activations).
Compute SAE training loss: reconstruction MSE + L1 sparsity penalty.
Get the output size of the SAE (same as input_size).
Types
@type build_opt() :: {:input_size, pos_integer()} | {:dict_size, pos_integer()} | {:top_k, pos_integer()} | {:sparsity, :top_k | :l1}
Options for build/1.
Functions
Build a sparse autoencoder.
Options
- :input_size - Dimension of input activations (required)
- :dict_size - Number of dictionary features, typically >> input_size (default: 4096)
- :top_k - Number of active features when sparsity: :top_k (default: 32)
- :sparsity - Sparsity mode: :top_k or :l1 (default: :top_k)
Returns
An Axon model mapping [batch, input_size] to [batch, input_size].
Build the encoder portion only (for extracting sparse activations).
Returns the sparse hidden activations [batch, dict_size].
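The encoder-only path is the usual way to extract feature activations for analysis. A numpy sketch of the computation it performs, assuming hypothetical weight and bias arrays w_enc and b_enc (the actual parameters live in the trained Axon model):

```python
import numpy as np

def encode(x, w_enc, b_enc, k):
    """Encoder-only pass: dense(dict_size) + ReLU, then keep the top-k activations."""
    h = np.maximum(x @ w_enc + b_enc, 0.0)  # dense + ReLU
    out = np.zeros_like(h)
    idx = np.argsort(h, axis=-1)[:, -k:]    # k largest entries per sample
    np.put_along_axis(out, idx, np.take_along_axis(h, idx, axis=-1), axis=-1)
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 8))      # [batch, input_size]
w = rng.normal(size=(8, 32))     # input_size -> dict_size (overcomplete)
acts = encode(x, w, np.zeros(32), k=4)
print(acts.shape)                # (4, 32), at most 4 nonzeros per row
```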
@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()
Compute SAE training loss: reconstruction MSE + L1 sparsity penalty.
Parameters
- input - Original activations [batch, input_size]
- reconstruction - SAE output [batch, input_size]
- hidden_acts - Sparse hidden activations [batch, dict_size]
- opts - Options: :l1_coeff - L1 penalty coefficient (default: 1.0e-3)
Returns
Scalar loss tensor.
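A minimal numpy sketch of this objective, reconstruction MSE plus an L1 penalty on the hidden activations. The exact reductions (mean over all elements for the MSE, per-sample L1 averaged over the batch) are an assumption for illustration and may differ from this module's reduction choices:

```python
import numpy as np

def sae_loss(x, x_hat, hidden_acts, l1_coeff=1.0e-3):
    """Reconstruction MSE + l1_coeff * L1 penalty on hidden activations."""
    mse = np.mean((x - x_hat) ** 2)
    l1 = np.mean(np.sum(np.abs(hidden_acts), axis=-1))  # per-sample L1, batch-averaged
    return mse + l1_coeff * l1
```

A perfect reconstruction with all-zero hidden activations yields a loss of exactly 0; any nonzero activation adds l1_coeff times its magnitude, which is what pushes the dictionary toward sparse codes in :l1 mode.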
@spec output_size(keyword()) :: pos_integer()
Get the output size of the SAE (same as input_size).