Engram: O(1) Hash-Based Associative Memory via Locality-Sensitive Hashing.
Engram implements a fast key-value memory that stores and retrieves values in amortised O(1) time using Locality-Sensitive Hashing (LSH). Multiple independent hash tables reduce collision probability, and exponential moving-average (EMA) writes allow smooth interpolation of stored values.
Motivation
Classical associative memories such as Neural Turing Machines (NTM) and Memory Networks attend over all N memory slots for every query, which costs O(N). Engram instead hashes each query into one bucket per table, so the hashing step of a read or write costs O(hash_bits × key_dim) per table, independent of the number of stored memories.
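The claim that several independent tables reduce collisions can be made concrete with the standard sign-random-projection result. The sketch below assumes that projection directions are drawn from an isotropic Gaussian and independently per bit and per table (the source does not state the distribution), and writes θ for the angle between two keys x and y.

```latex
% Assumption: Gaussian projection directions, independent across bits and tables.
\[
\Pr[\mathrm{bit}_i(x) = \mathrm{bit}_i(y)] = 1 - \frac{\theta}{\pi},
\qquad
\Pr[\mathrm{bucket}_t(x) = \mathrm{bucket}_t(y)]
  = \left(1 - \frac{\theta}{\pi}\right)^{\mathrm{hash\_bits}}
\]
\[
\Pr[\text{same bucket in every table}]
  = \left(1 - \frac{\theta}{\pi}\right)^{\mathrm{hash\_bits}\,\cdot\,\mathrm{num\_tables}}
\]
% Worked example: orthogonal keys (theta = pi/2), hash_bits = 8, num_tables = 4.
\[
\left(\tfrac{1}{2}\right)^{8} = \tfrac{1}{256} \approx 0.0039,
\qquad
\left(\tfrac{1}{2}\right)^{32} \approx 2.3 \times 10^{-10}
\]
```

Under these assumptions, two unrelated keys share a bucket in a single table about 0.4% of the time, and almost never share a bucket in all four tables at once.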
Architecture
Query key [batch, key_dim]
      |
      v
+---------------------------------------+
| LSH Hash (per table t):               |
|   project = W_t @ key                 |   W_t ∈ R^{hash_bits × key_dim}
|   bits    = sign(project)             |   {0,1}^hash_bits
|   bucket  = bits → int                |   0..num_buckets-1
+---------------------------------------+
      |
      v   (for each of num_tables tables)
+---------------------------------------+
| Memory Slots                          |
| [num_tables, num_buckets, value_dim]  |
+---------------------------------------+
      |
      v
Retrieve slot per table → average across tables
      |
      v
Retrieved value [batch, value_dim]
Hashing
For num_buckets = 256 = 2^8:
- hash_bits = 8 random projection vectors per table
- W shape: [num_tables, hash_bits, key_dim]
- bucket = sum((W @ key >= 0) * [1, 2, 4, ..., 2^{hash_bits-1}])
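The bucket formula for a single table can be sketched in plain Elixir, using lists instead of Nx tensors. The module name `HashSketch` and its `bucket/2` function are hypothetical helpers for illustration, not part of Engram.

```elixir
defmodule HashSketch do
  # Hypothetical illustration of the bucket formula; not Engram's implementation.

  # Dot product of two equal-length lists of numbers.
  defp dot(a, b) do
    Enum.zip(a, b)
    |> Enum.reduce(0.0, fn {x, y}, acc -> acc + x * y end)
  end

  # w: hash_bits projection vectors (each a list of key_dim numbers) for one table.
  # Bit i is 1 when the i-th projection is non-negative, and carries weight 2^i.
  def bucket(w, key) do
    w
    |> Enum.with_index()
    |> Enum.reduce(0, fn {row, i}, acc ->
      bit = if dot(row, key) >= 0, do: 1, else: 0
      acc + Bitwise.bsl(bit, i)
    end)
  end
end

# Three hash bits over a 2-dimensional key, so buckets fall in 0..7.
w = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
IO.inspect(HashSketch.bucket(w, [0.5, -2.0]))
# => 5 (bits 0 and 2 are set: 1 + 4)
```

With hash_bits = 8 the same loop yields an integer in 0..255, matching num_buckets = 256.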
Write (EMA update)
memory[t, hash(key)] ← decay × memory[t, hash(key)] + (1 − decay) × value
Usage
# Create memory state
mem = Engram.new(key_dim: 32, value_dim: 64)
# Read
result = Engram.engram_read(mem, query) # [batch, value_dim]
# Write (returns updated memory)
mem = Engram.engram_write(mem, key, value, decay: 0.99)
# Build an Axon model for differentiable reads
model = Engram.build(key_dim: 32, value_dim: 64)
References
- Andoni & Indyk, "Near-Optimal Hashing Algorithms for Approximate Nearest Neighbor in High Dimensions" (FOCS 2006)
- Kitaev et al., "Reformer: The Efficient Transformer" (ICLR 2020), which applies Locality-Sensitive Hashing to attention
Summary
Functions
- build/1: Build an Axon model for differentiable Engram reads.
- engram_read/2: Read from Engram memory using LSH-based lookup.
- engram_write/4: Write a key-value pair into Engram memory using EMA updates.
- new/1: Initialise a fresh Engram memory state.
Types
@type build_opt() :: {:key_dim, pos_integer()} | {:num_buckets, pos_integer()} | {:num_tables, pos_integer()} | {:value_dim, pos_integer()}
Options for build/1.
Functions
Build an Axon model for differentiable Engram reads.
The model takes a query and the current memory slots as inputs; the LSH projection matrices are trainable parameters (useful when the hash function itself should be optimised end-to-end).
Options
- :key_dim - Query/key dimension (required)
- :value_dim - Value / slot dimension (required)
- :num_buckets - Number of hash buckets; must be a power of 2 (default: 256)
- :num_tables - Number of independent hash tables (default: 4)
Returns
An Axon model with inputs:
- "query": [batch, key_dim]
- "memory_slots": [num_tables, num_buckets, value_dim]
The model outputs the retrieved values, shape [batch, value_dim].
@spec engram_read(%{hash_matrices: Nx.Tensor.t(), slots: Nx.Tensor.t()}, Nx.Tensor.t()) :: Nx.Tensor.t()
Read from Engram memory using LSH-based lookup.
Parameters
- memory - Memory state map with :hash_matrices and :slots
- query - Query tensor [batch, key_dim] or [key_dim]
Returns
Retrieved value [batch, value_dim] (averaged across tables).
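The read path (pick one slot per table, then average across tables) can be sketched for a single query in plain Elixir. `ReadSketch` is a hypothetical helper that takes the per-table bucket indices as given, and nested lists stand in for the [num_tables, num_buckets, value_dim] tensor.

```elixir
defmodule ReadSketch do
  # Hypothetical illustration of the per-table lookup and averaging step.

  # slots: nested lists shaped [num_tables][num_buckets][value_dim].
  # buckets: one bucket index per table, as produced by the LSH step.
  def read(slots, buckets) do
    # Select the addressed slot in every table.
    picked =
      Enum.zip(slots, buckets)
      |> Enum.map(fn {table, b} -> Enum.at(table, b) end)

    n = length(picked)

    # Average the selected slots element-wise across tables.
    picked
    |> Enum.zip()
    |> Enum.map(fn column -> Enum.sum(Tuple.to_list(column)) / n end)
  end
end

# Two tables, two buckets each, value_dim = 2.
slots = [
  [[1.0, 2.0], [3.0, 4.0]],
  [[5.0, 6.0], [7.0, 8.0]]
]

IO.inspect(ReadSketch.read(slots, [1, 0]))
# => [4.0, 5.0] (mean of [3.0, 4.0] and [5.0, 6.0])
```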
@spec engram_write(%{hash_matrices: Nx.Tensor.t(), slots: Nx.Tensor.t()}, Nx.Tensor.t(), Nx.Tensor.t(), keyword()) :: %{hash_matrices: Nx.Tensor.t(), slots: Nx.Tensor.t()}
Write a key-value pair into Engram memory using EMA updates.
Each hash table independently maps the key to a bucket and applies:
slot[t, bucket] ← decay × slot[t, bucket] + (1 − decay) × value
Parameters
- memory - Memory state map with :hash_matrices and :slots
- key - Key tensor [key_dim] or [1, key_dim]
- value - Value tensor [value_dim] or [1, value_dim]
Options
- :decay - EMA decay coefficient (default: 0.99)
Returns
Updated memory state map.
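The EMA update can be sketched in plain Elixir as follows. `WriteSketch` is a hypothetical helper operating on nested lists, with the per-table bucket indices supplied by the caller. The example uses decay = 0.5 only to keep the arithmetic easy to follow.

```elixir
defmodule WriteSketch do
  # Hypothetical illustration of the EMA write; not Engram's implementation.

  # Blend `value` into the slot at `bucket`, leaving all other slots untouched.
  def write_table(table, bucket, value, decay) do
    List.update_at(table, bucket, fn slot ->
      Enum.zip(slot, value)
      |> Enum.map(fn {old, incoming} -> decay * old + (1 - decay) * incoming end)
    end)
  end

  # Apply the update independently in every table.
  def write(slots, buckets, value, decay \\ 0.99) do
    Enum.zip(slots, buckets)
    |> Enum.map(fn {table, b} -> write_table(table, b, value, decay) end)
  end
end

# One table with two zero-filled slots, value_dim = 2.
slots = [[[0.0, 0.0], [0.0, 0.0]]]
once = WriteSketch.write(slots, [1], [2.0, 4.0], 0.5)
twice = WriteSketch.write(once, [1], [2.0, 4.0], 0.5)
IO.inspect(once)
# => [[[0.0, 0.0], [1.0, 2.0]]]
IO.inspect(twice)
# => [[[0.0, 0.0], [1.5, 3.0]]]
```

It follows from the formula that writing the same value n times into a zero-filled slot leaves (1 − decay^n) × value there, so with decay = 0.99 a single write stores only 1% of the value.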
@spec new([build_opt() | {:seed, non_neg_integer()}]) :: %{hash_matrices: Nx.Tensor.t(), slots: Nx.Tensor.t()}
Initialise a fresh Engram memory state.
Creates random LSH projection matrices (normalised) and zero-filled memory slots.
Options
- :key_dim - Key dimension (required)
- :value_dim - Value dimension (required)
- :num_buckets - Number of hash buckets (default: 256)
- :num_tables - Number of hash tables (default: 4)
- :seed - Random seed for reproducibility (default: 0)
Returns
A map %{hash_matrices: Nx.Tensor.t(), slots: Nx.Tensor.t()}.
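A rough picture of what such an initial state might contain, sketched with nested lists and Erlang's `:rand` instead of Nx. `InitSketch` is a hypothetical stand-in. It assumes that "normalised" means each projection vector is scaled to unit length, that directions are Gaussian, and that hash_bits = log2(num_buckets); none of these details is confirmed by the source.

```elixir
defmodule InitSketch do
  # Hypothetical illustration of the initial state; not Engram's implementation.
  def new(opts) do
    key_dim = Keyword.fetch!(opts, :key_dim)
    value_dim = Keyword.fetch!(opts, :value_dim)
    num_buckets = Keyword.get(opts, :num_buckets, 256)
    num_tables = Keyword.get(opts, :num_tables, 4)
    seed = Keyword.get(opts, :seed, 0)

    # Assumption: num_buckets is a power of two, so this is exact.
    hash_bits = round(:math.log2(num_buckets))

    # Seeds the calling process's generator so repeated calls are reproducible.
    :rand.seed(:exsss, {seed, seed + 1, seed + 2})

    hash_matrices =
      for _t <- 1..num_tables do
        for _b <- 1..hash_bits do
          row = for _k <- 1..key_dim, do: :rand.normal()
          norm = :math.sqrt(Enum.reduce(row, 0.0, fn x, acc -> acc + x * x end))
          # Assumption: "normalised" means unit Euclidean length per vector.
          Enum.map(row, &(&1 / norm))
        end
      end

    # Zero-filled slots shaped [num_tables][num_buckets][value_dim].
    zero_slot = List.duplicate(0.0, value_dim)
    slots = List.duplicate(List.duplicate(zero_slot, num_buckets), num_tables)

    %{hash_matrices: hash_matrices, slots: slots}
  end
end

mem = InitSketch.new(key_dim: 3, value_dim: 2, num_buckets: 8, num_tables: 2)
IO.inspect(length(mem.hash_matrices))
IO.inspect(length(hd(mem.slots)))
```

Normalising the rows does not change which bucket a key lands in, since only the sign of each projection is used.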