Edifice.Interpretability.Transcoder (Edifice v0.2.0)


Transcoder for cross-layer mechanistic interpretability.

A transcoder is like a sparse autoencoder, but instead of reconstructing its own input it maps one layer's activations to another layer's. Input and output dimensions can differ, which enables analysis of how representations transform across layers.

Architecture

Input [batch, input_size]   (layer N activations)
      |
Encoder: dense(dict_size) + ReLU
      |
Sparsify: top-k
      |
[batch, dict_size]  (sparse cross-layer features)
      |
Decoder: dense(output_size)
      |
Output [batch, output_size]  (layer M activation prediction)
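
The data flow in the diagram can be sketched directly in Nx. This is a minimal illustration under stated assumptions, not the module's implementation: `TranscoderSketch`, its `forward/6` function, and the explicit weight arguments are hypothetical names, and the top-k step is done by thresholding each row at its k-th largest activation (so exact ties could keep more than k entries).

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TranscoderSketch do
  # Hypothetical helper, not part of Edifice: forward pass with explicit weights.
  def forward(x, w_enc, b_enc, w_dec, b_dec, k) do
    # Encoder: dense(dict_size) + ReLU -> [batch, dict_size]
    pre = Nx.add(Nx.dot(x, w_enc), b_enc)
    acts = Nx.max(pre, 0)

    # Sparsify: keep each row's k largest activations and zero the rest.
    {top_vals, _indices} = Nx.top_k(acts, k: k)
    threshold = Nx.reduce_min(top_vals, axes: [-1], keep_axes: true)
    sparse = Nx.select(Nx.greater_equal(acts, threshold), acts, 0.0)

    # Decoder: dense(output_size) -> [batch, output_size]
    out = Nx.add(Nx.dot(sparse, w_dec), b_dec)
    {sparse, out}
  end
end

# Toy sizes (assumed for illustration): input_size 4, dict_size 8, output_size 3, top_k 2.
key = Nx.Random.key(42)
{x, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {2, 4})
{w_enc, key} = Nx.Random.normal(key, 0.0, 1.0, shape: {4, 8})
{w_dec, _key} = Nx.Random.normal(key, 0.0, 1.0, shape: {8, 3})
b_enc = Nx.broadcast(0.0, {8})
b_dec = Nx.broadcast(0.0, {3})

{sparse, out} = TranscoderSketch.forward(x, w_enc, b_enc, w_dec, b_dec, 2)

# Count the nonzero features per row of the sparse code.
active_per_row = Nx.sum(Nx.greater(sparse, 0), axes: [-1])
IO.inspect(Nx.shape(sparse), label: "sparse shape")
IO.inspect(Nx.shape(out), label: "output shape")
IO.inspect(active_per_row, label: "active features per row")
```

The sparse `[batch, dict_size]` tensor is what an interpretability analysis would inspect, while the decoder output is compared against the target layer's activations.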

Usage

model = Transcoder.build(
  input_size: 256,
  output_size: 512,
  dict_size: 4096,
  top_k: 32
)

References

  • Dunefsky et al., "Transcoders Find Interpretable LLM Feature Circuits" (2024)

Summary

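Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a transcoder.

loss(target, reconstruction, hidden_acts, opts \\ [])
Compute transcoder training loss: reconstruction MSE + L1 sparsity penalty.

output_size(opts \\ [])
Get the output size of the transcoder.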

Types

build_opt()

@type build_opt() ::
  {:input_size, pos_integer()}
  | {:output_size, pos_integer()}
  | {:dict_size, pos_integer()}
  | {:top_k, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a transcoder.

Options

  • :input_size - Source layer activation dimension (required)
  • :output_size - Target layer activation dimension (required)
  • :dict_size - Number of dictionary features (default: 4096)
  • :top_k - Number of active features (default: 32)

Returns

An Axon model mapping [batch, input_size] to [batch, output_size].

loss(target, reconstruction, hidden_acts, opts \\ [])

@spec loss(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: Nx.Tensor.t()

Compute transcoder training loss: reconstruction MSE + L1 sparsity penalty.

Parameters

  • target - Target layer activations [batch, output_size]
  • reconstruction - Transcoder output [batch, output_size]
  • hidden_acts - Sparse hidden activations [batch, dict_size]
  • opts - Options:
    • :l1_coeff - L1 penalty coefficient (default: 1.0e-3)

Returns

Scalar loss tensor.
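
To make the two terms concrete, here is a minimal Nx sketch of an MSE-plus-L1 objective with the same argument order. `TranscoderLossSketch` is a hypothetical module, and the reductions are assumptions: MSE is averaged over all elements, and the L1 term sums absolute activations per example and then averages over the batch. The library's own `loss/4` may reduce differently.

```elixir
Mix.install([{:nx, "~> 0.7"}])

defmodule TranscoderLossSketch do
  # Hypothetical re-implementation for illustration; the reductions are assumptions.
  def loss(target, reconstruction, hidden_acts, opts \\ []) do
    l1_coeff = Keyword.get(opts, :l1_coeff, 1.0e-3)

    # Reconstruction term: mean squared error over all elements.
    mse =
      reconstruction
      |> Nx.subtract(target)
      |> Nx.pow(2)
      |> Nx.mean()

    # Sparsity term: L1 norm of each row of hidden activations, averaged over the batch.
    l1 =
      hidden_acts
      |> Nx.abs()
      |> Nx.sum(axes: [-1])
      |> Nx.mean()

    Nx.add(mse, Nx.multiply(l1_coeff, l1))
  end
end

target = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
reconstruction = Nx.tensor([[1.0, 2.0], [3.0, 2.0]])
hidden_acts = Nx.tensor([[0.0, 2.0, 0.0], [1.0, 0.0, 3.0]])

# MSE = (0 + 0 + 0 + 4) / 4 = 1.0; L1 = (2 + 4) / 2 = 3.0; total = 1.0 + 0.5 * 3.0
loss = TranscoderLossSketch.loss(target, reconstruction, hidden_acts, l1_coeff: 0.5)
IO.inspect(Nx.to_number(loss))
# => 2.5
```

A larger `:l1_coeff` pushes the hidden activations toward zero at the cost of reconstruction accuracy.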

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of the transcoder.