Edifice.Energy.Hopfield (Edifice v0.2.0)


Modern Continuous Hopfield Network (Ramsauer et al., 2020).

Classical Hopfield networks store binary patterns and recall them via energy minimization. Modern Hopfield networks replace the quadratic energy with an exponential interaction function, yielding:

  1. Exponentially many stored patterns (vs. capacity linear in the dimension for the classical network)
  2. Single-step convergence for retrieval
  3. Mathematical equivalence to attention with softmax

Key Insight

The update rule softmax(beta * X * Y^T) * Y is exactly the attention mechanism with query X, keys Y, values Y, and inverse temperature beta. Higher beta -> sharper retrieval (closer to nearest-neighbor lookup). Lower beta -> softer retrieval (closer to averaging over the stored patterns).
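As a minimal sketch of this update rule in plain Elixir (small lists in place of Nx tensors; the module and function names here are illustrative, not part of the library API):

```elixir
defmodule HopfieldSketch do
  def dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {u, v} -> u * v end) |> Enum.sum()

  def softmax(scores) do
    m = Enum.max(scores)
    exps = Enum.map(scores, &:math.exp(&1 - m))
    z = Enum.sum(exps)
    Enum.map(exps, &(&1 / z))
  end

  # One Hopfield update: softmax(beta * x . y_i) weighted sum of the patterns.
  def retrieve(x, patterns, beta) do
    weights = patterns |> Enum.map(&(beta * dot(x, &1))) |> softmax()

    patterns
    |> Enum.zip(weights)
    |> Enum.map(fn {y, w} -> Enum.map(y, &(&1 * w)) end)
    |> Enum.zip_with(&Enum.sum/1)
  end
end

patterns = [[1.0, 0.0], [0.0, 1.0]]
query = [0.9, 0.1]

# High beta: retrieval snaps to the nearest stored pattern.
IO.inspect(HopfieldSketch.retrieve(query, patterns, 50.0))
# Low beta: retrieval averages over the stored patterns.
IO.inspect(HopfieldSketch.retrieve(query, patterns, 0.01))
```

Running the same query at the two temperatures makes the sharp-vs-soft behavior concrete: at beta = 50 the output is essentially the first pattern, while at beta = 0.01 it is close to the mean of both.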

Architecture

Query X [batch, seq_len, input_dim]
     |
     v
+-----------------------------+
| Similarity: beta * X * Y^T  |
+-----------------------------+
     |
     v
+-----------------------------+
|       softmax(scores)       |
+-----------------------------+
     |
     v
+-----------------------------+
|   Retrieval: weights * Y    |
+-----------------------------+
     |
     v
Output [batch, seq_len, pattern_dim]

Usage

# Build a Hopfield layer
model = Hopfield.build(input_dim: 128, num_patterns: 64, pattern_dim: 128)

# Build an associative memory
model = Hopfield.build_associative_memory(
  input_dim: 256,
  num_patterns: 128,
  pattern_dim: 256,
  beta: 2.0,
  num_heads: 4
)

References

  • Ramsauer et al. (2020). Hopfield Networks is All You Need. arXiv:2008.02217

Summary

Types

build_opt()

Options for build/1.

Functions

build(opts \\ [])

Build a Modern Hopfield layer as an Axon model.

build_associative_memory(opts \\ [])

Build a Hopfield-based associative memory network.

energy(query, patterns, beta)

Compute the Hopfield energy for a state and stored patterns.

hopfield_layer(input, opts \\ [])

Apply a single Hopfield attention layer.

Types

build_opt()

@type build_opt() ::
  {:beta, float()}
  | {:input_dim, pos_integer()}
  | {:num_patterns, pos_integer()}
  | {:pattern_dim, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Modern Hopfield layer as an Axon model.

Stores learnable patterns and retrieves them via exponential similarity (equivalent to attention).

Options

  • :input_dim - Input feature dimension (required)
  • :num_patterns - Number of stored patterns N (default: 64)
  • :pattern_dim - Dimension of each pattern M (default: 128)
  • :beta - Inverse temperature for softmax sharpness (default: 1.0)

Returns

An Axon model: [batch, input_dim] -> [batch, pattern_dim]

build_associative_memory(opts \\ [])

@spec build_associative_memory(keyword()) :: Axon.t()

Build a Hopfield-based associative memory network.

Multi-layer architecture with multiple Hopfield heads for robust pattern storage and retrieval.

Options

  • :input_dim - Input feature dimension (required)
  • :num_patterns - Number of stored patterns per head (default: 64)
  • :pattern_dim - Dimension of each pattern (default: 128)
  • :beta - Inverse temperature (default: 1.0)
  • :num_heads - Number of parallel Hopfield heads (default: 1)
  • :hidden_size - Hidden dimension for projection layers (default: 256)
  • :num_layers - Number of Hopfield layers (default: 2)
  • :dropout - Dropout rate (default: 0.1)

Returns

An Axon model: [batch, input_dim] -> [batch, hidden_size]

energy(query, patterns, beta)

@spec energy(Nx.Tensor.t(), Nx.Tensor.t(), float()) :: Nx.Tensor.t()

Compute the Hopfield energy for a state and stored patterns.

Energy: E(x) = -(1/beta) * log(sum_i exp(beta * x^T * y_i))

Lower energy = better match to stored patterns. This is a pure numerical function for analysis/debugging.

Parameters

  • query - Query state [batch, dim]
  • patterns - Stored patterns [num_patterns, dim]
  • beta - Inverse temperature

Returns

Energy values [batch]
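As a hedged illustration, the energy can be sketched in plain Elixir using the standard log-sum-exp form E(x) = -(1/beta) * log(sum_i exp(beta * x^T y_i)), with small lists in place of Nx tensors (the module name is illustrative, not the library's implementation):

```elixir
defmodule EnergySketch do
  def dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {u, v} -> u * v end) |> Enum.sum()

  # E(x) = -(1/beta) * log(sum_i exp(beta * x . y_i)),
  # computed with the max score subtracted for numerical stability.
  def energy(x, patterns, beta) do
    scores = Enum.map(patterns, &(beta * dot(x, &1)))
    m = Enum.max(scores)
    lse = m + :math.log(Enum.sum(Enum.map(scores, &:math.exp(&1 - m))))
    -lse / beta
  end
end

patterns = [[1.0, 0.0], [0.0, 1.0]]

# A query close to a stored pattern sits lower on the energy surface
# than one far from all of them.
e_near = EnergySketch.energy([1.0, 0.0], patterns, 2.0)
e_far = EnergySketch.energy([0.3, 0.3], patterns, 2.0)
IO.inspect({e_near, e_far})
```

Checking a near-pattern query against a far one is a quick sanity test of the "lower energy = better match" claim above.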

hopfield_layer(input, opts \\ [])

@spec hopfield_layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Apply a single Hopfield attention layer.

Computes: softmax(beta * X * Y^T) * Y

where Y are stored (learnable) patterns. The patterns are implemented as a dense projection to num_patterns (computing similarity scores), followed by a dense projection from num_patterns to pattern_dim (weighted retrieval).
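The two-stage view can be sketched in plain Elixir over a batch, assuming input_dim == pattern_dim with the patterns tied across both stages as in the formula (in the layer itself the two dense projections carry separate learnable kernels; the module name is illustrative):

```elixir
defmodule HopfieldLayerSketch do
  def dot(a, b), do: Enum.zip(a, b) |> Enum.map(fn {u, v} -> u * v end) |> Enum.sum()

  def softmax(scores) do
    m = Enum.max(scores)
    exps = Enum.map(scores, &:math.exp(&1 - m))
    z = Enum.sum(exps)
    Enum.map(exps, &(&1 / z))
  end

  # Stage 1 ("dense to num_patterns"): similarity scores beta * x . y_i.
  # Stage 2 ("dense to pattern_dim"): softmax weights applied back to Y.
  def layer(batch, patterns, beta) do
    Enum.map(batch, fn x ->
      weights = patterns |> Enum.map(&(beta * dot(x, &1))) |> softmax()

      patterns
      |> Enum.zip(weights)
      |> Enum.map(fn {y, w} -> Enum.map(y, &(&1 * w)) end)
      |> Enum.zip_with(&Enum.sum/1)
    end)
  end
end

# [batch = 2, input_dim = 2] -> [batch = 2, pattern_dim = 2]
batch = [[0.9, 0.1], [0.2, 0.8]]
patterns = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
IO.inspect(HopfieldLayerSketch.layer(batch, patterns, 4.0))
```

Each output row is a convex combination of the stored patterns, which is why the shape contract is [batch, input_dim] -> [batch, pattern_dim] regardless of num_patterns.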

Parameters

  • input - Axon node with shape [batch, input_dim] or [batch, seq_len, input_dim]

Options

  • :num_patterns - Number of stored patterns (default: 64)
  • :pattern_dim - Dimension of each pattern (default: 128)
  • :beta - Inverse temperature (default: 1.0)
  • :name - Layer name prefix (default: "hopfield")

Returns

An Axon node with shape [batch, pattern_dim] or [batch, seq_len, pattern_dim]