Edifice.Memory.NTM (Edifice v0.2.0)


Neural Turing Machine (Graves et al., 2014).

An NTM augments a neural network controller with an external memory matrix. The controller reads from and writes to this matrix through differentiable attention mechanisms. Because the whole system is differentiable end to end, it can be trained with gradient descent to learn simple algorithms such as copying, sorting, and associative recall.

Architecture

Input [batch, input_size]
      |
      +------------------+
      |                  |
      v                  v
+------------+    +-----------+
| Controller |    |  Memory   |
|   (LSTM)   |    |  [N x M]  |
+------------+    +-----------+
      |                ^  |
      +--+--+          |  |
      |  |  |          |  |
      v  v  v          |  |
    Read Write    Read/ Write
    Head  Head    Addressing
      |    |           |
      +----+-----------+
      |
      v
Output [batch, output_size]

Addressing Mechanism

The NTM uses a 4-stage addressing pipeline for each head:

  1. Content addressing: Cosine similarity between the controller's key and each memory row, scaled by the sharpness parameter beta and normalized with a softmax over the N rows
  2. Interpolation: w = g * w_content + (1-g) * w_prev blends the content-based weights with the previous step's weights, using a gate g in [0, 1]
  3. Circular shift: Convolves the weights circularly with a 3-element kernel [shift_left, stay, shift_right], generated by the controller, to move the focus by at most one row
  4. Sharpening: w = w^gamma / sum(w^gamma) concentrates the distribution; with gamma >= 1 it counteracts the blurring introduced by the shift
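Stages 2 to 4 can be sketched on plain Elixir lists for a single batch element, starting from an already computed content weighting. NTMAddressingSketch is a hypothetical module, not Edifice's code. It assumes the kernel order [shift_left, stay, shift_right] given above. It also follows the paper's convention that a right shift moves focus from row i to row i + 1, wrapping around at the ends.

```elixir
defmodule NTMAddressingSketch do
  @moduledoc "Hypothetical sketch of NTM addressing stages 2-4 (one batch element)."

  # Stage 2: blend content weights with the previous step's weights.
  # g is the interpolation gate, assumed to lie in [0, 1].
  def interpolate(w_content, w_prev, g) do
    Enum.zip_with(w_content, w_prev, fn wc, wp -> g * wc + (1 - g) * wp end)
  end

  # Stage 3: circular convolution with [shift_left, stay, shift_right].
  def shift(w, [s_left, s_stay, s_right]) do
    n = length(w)
    t = List.to_tuple(w)

    w
    |> Enum.with_index()
    |> Enum.map(fn {w_i, i} ->
      # Row i receives mass from row i+1 (left shift), itself (stay),
      # and row i-1 (right shift), wrapping around at the edges.
      s_left * elem(t, rem(i + 1, n)) + s_stay * w_i + s_right * elem(t, rem(i - 1 + n, n))
    end)
  end

  # Stage 4: raise each weight to gamma (>= 1) and renormalize.
  def sharpen(w, gamma) do
    powered = Enum.map(w, &:math.pow(&1, gamma))
    total = Enum.sum(powered)
    Enum.map(powered, &(&1 / total))
  end

  # Stages 2 -> 3 -> 4 in order.
  def address(w_content, w_prev, g, kernel, gamma) do
    w_content
    |> interpolate(w_prev, g)
    |> shift(kernel)
    |> sharpen(gamma)
  end
end
```

With g = 1.0 and the kernel [0.0, 0.0, 1.0], a focus on row 1 moves to row 2. A blurring kernel such as [0.1, 0.8, 0.1] spreads the focus over neighboring rows, and a sharpening step with gamma > 1 then partly concentrates it again.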

Write Mechanism

The write head updates memory via erase-then-add:

M_new = M * (1 - w * e^T) + w * a^T

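where w is the N-element write weighting produced by the addressing pipeline, e is the M-element erase vector (passed through a sigmoid, so each entry lies in [0, 1]), and a is the M-element add vector. The products w * e^T and w * a^T are [N x M] outer products, 1 is an all-ones [N x M] matrix, and the first * (between M and the bracket) is elementwise.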
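The update can be written out per entry as M_new[i][j] = M[i][j] * (1 - w[i] * e[j]) + w[i] * a[j]. The following is a minimal sketch on plain lists for one batch element. NTMWriteSketch is a hypothetical name, and this is not Edifice's Nx implementation.

```elixir
defmodule NTMWriteSketch do
  @moduledoc "Hypothetical erase-then-add write on plain lists (one batch element)."

  # memory: N rows of M floats; w: N write weights;
  # erase: M values in [0, 1]; add: M values.
  # Computes M_new[i][j] = M[i][j] * (1 - w[i] * e[j]) + w[i] * a[j].
  def write(memory, w, erase, add) do
    Enum.zip_with(memory, w, fn row, w_i ->
      Enum.zip_with([row, erase, add], fn [m, e, a] ->
        m * (1 - w_i * e) + w_i * a
      end)
    end)
  end
end
```

For example, write([[1.0, 2.0], [3.0, 4.0]], [1.0, 0.0], [1.0, 0.0], [5.0, 6.0]) fully addresses row 0. Its first column is erased and replaced by 5.0, and its second column keeps 2.0 and gains 6.0. Row 1 has zero weight and is left unchanged, giving [[5.0, 8.0], [3.0, 4.0]].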

Usage

model = NTM.build(
  input_size: 64,
  memory_size: 128,
  memory_dim: 32,
  controller_size: 256,
  num_heads: 1
)

References

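Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build a Neural Turing Machine.

build_controller(input, controller_size) - Build the LSTM controller that drives the read/write heads.

content_addressing(key, memory, beta) - Content-based addressing using cosine similarity.

read_head(controller_out, memory, opts \\ []) - Compute read head: full addressing pipeline + weighted read from memory.

write_head(controller_out, memory, opts \\ []) - Compute write head: full addressing pipeline + erase/add memory update.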
Types

build_opt()

@type build_opt() ::
  {:controller_size, pos_integer()}
  | {:input_size, pos_integer()}
  | {:memory_dim, pos_integer()}
  | {:memory_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:output_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Neural Turing Machine.

The NTM consists of:

  • An LSTM controller that processes inputs and generates head parameters
  • A differentiable memory matrix accessed via read and write heads
  • Full 4-stage addressing: content → interpolation → shift → sharpening

Options

  • :input_size - Input feature dimension (required)
  • :memory_size - Number of memory rows N (default: 128)
  • :memory_dim - Dimension of each memory row M (default: 32)
  • :controller_size - LSTM controller hidden size (default: 256)
  • :num_heads - Number of read/write heads (default: 1)
  • :output_size - Output dimension (default: same as input_size)
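The option list above behaves like a keyword list with defaults. NTMOptsSketch is a hypothetical module that resolves options the way the documentation describes: :input_size is required, :output_size falls back to :input_size, and the remaining options use the listed defaults. It illustrates the documented contract only. Whether Edifice's build/1 raises a KeyError for a missing :input_size is an assumption.

```elixir
defmodule NTMOptsSketch do
  @moduledoc "Hypothetical resolution of build/1 options using the documented defaults."

  def resolve(opts) do
    # Required: raises KeyError when :input_size is missing (assumed behavior).
    input_size = Keyword.fetch!(opts, :input_size)

    %{
      input_size: input_size,
      memory_size: Keyword.get(opts, :memory_size, 128),
      memory_dim: Keyword.get(opts, :memory_dim, 32),
      controller_size: Keyword.get(opts, :controller_size, 256),
      num_heads: Keyword.get(opts, :num_heads, 1),
      # Defaults to the input size, as documented.
      output_size: Keyword.get(opts, :output_size, input_size)
    }
  end
end
```

Under these rules, resolve(input_size: 64) gives a 128 x 32 memory, a 256-unit controller, one head, and an output size of 64.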

Returns

An Axon model taking input [batch, input_size] and memory [batch, N, M], producing output [batch, output_size].

build_controller(input, controller_size)

@spec build_controller(Axon.t(), pos_integer()) :: Axon.t()

Build the LSTM controller that drives the read/write heads.

The controller processes the combined input (the external input concatenated with the previous step's read vectors) and produces a hidden state used to parameterize the head operations.
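The sketch below does the size bookkeeping for the controller's input and output. It assumes one read vector of length memory_dim per head is fed back. It also assumes each head uses the parameterization from Graves et al.: a key (M values), beta, the gate g, a 3-element shift kernel, and gamma, with a write head adding erase and add vectors of M values each. NTMSizesSketch and its functions are hypothetical, not part of Edifice.

```elixir
defmodule NTMSizesSketch do
  @moduledoc "Hypothetical size bookkeeping for an NTM controller; not Edifice's code."

  # Controller input: external input followed by each head's previous read vector.
  def combine(input, prev_reads), do: input ++ List.flatten(prev_reads)

  # Assumes one read vector of length memory_dim per head.
  def combined_dim(input_size, num_heads, memory_dim), do: input_size + num_heads * memory_dim

  # Addressing values per head: key (M) + beta (1) + gate g (1) + shift kernel (3) + gamma (1).
  def addressing_params(memory_dim), do: memory_dim + 1 + 1 + 3 + 1

  # A write head also emits an erase vector (M) and an add vector (M).
  def write_head_params(memory_dim), do: addressing_params(memory_dim) + 2 * memory_dim
end
```

With input_size: 64 from the usage example and the defaults memory_dim: 32 and num_heads: 1, these assumptions give a 96-dimensional controller input. A read head then needs 38 values from the controller state, and a write head needs 102.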

Parameters

  • input - Axon node with combined input [batch, combined_dim]
  • controller_size - Hidden dimension for the controller

Returns

An Axon node with shape [batch, controller_size]

content_addressing(key, memory, beta)

@spec content_addressing(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t()) :: Nx.Tensor.t()

Content-based addressing using cosine similarity.

Computes attention weights over memory rows based on cosine similarity between a query key and each memory row, scaled by sharpness beta.

w_i = softmax(beta * cosine_similarity(key, memory[i]))
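The formula above can be sketched for a single batch element with plain lists. NTMContentSketch is a hypothetical name. Two numerical-stability choices are assumed here and are not documented behavior of Edifice's content_addressing/3: a small epsilon in the cosine denominator, and subtracting the maximum score before exponentiating in the softmax.

```elixir
defmodule NTMContentSketch do
  @moduledoc "Hypothetical content addressing on plain lists (one batch element)."

  # key: M floats; memory: N rows of M floats; beta: non-negative sharpness.
  def content_addressing(key, memory, beta) do
    scores = Enum.map(memory, fn row -> beta * cosine(key, row) end)
    softmax(scores)
  end

  defp cosine(u, v) do
    dot = Enum.zip_with(u, v, &(&1 * &2)) |> Enum.sum()
    # Epsilon (an assumption) avoids division by zero for all-zero rows.
    dot / (norm(u) * norm(v) + 1.0e-8)
  end

  defp norm(v), do: v |> Enum.map(&(&1 * &1)) |> Enum.sum() |> :math.sqrt()

  # Softmax over the N rows; subtracting the max keeps :math.exp from overflowing.
  defp softmax(xs) do
    max = Enum.max(xs)
    exps = Enum.map(xs, &:math.exp(&1 - max))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end
end
```

With key [1.0, 0.0], memory rows [1.0, 0.0], [0.0, 1.0], and [-1.0, 0.0], and beta = 10.0, almost all weight lands on the first row. With beta = 0.0, the weights are uniform (1/3 each), which shows how beta controls the sharpness of the content focus.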

Parameters

  • key - Query key [batch, M]
  • memory - Memory matrix [batch, N, M]
  • beta - Sharpness parameter [batch, 1]

Returns

Attention weights [batch, N]

read_head(controller_out, memory, opts \\ [])

@spec read_head(Axon.t(), Axon.t(), keyword()) :: Axon.t()

Compute read head: full addressing pipeline + weighted read from memory.

Uses the 4-stage addressing pipeline (content → interpolation → shift → sharpening) to compute read weights, then reads a weighted sum from memory.
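The weighted read is r = sum_i w_i * M[i]. When w sums to 1, this is a convex combination of memory rows. The following is a minimal list-based sketch for one batch element. NTMReadSketch is a hypothetical name, and Edifice performs this step on Nx tensors.

```elixir
defmodule NTMReadSketch do
  @moduledoc "Hypothetical weighted read on plain lists (one batch element)."

  # memory: N rows of M floats; w: N read weights.
  # Returns the M-element read vector r[j] = sum_i w[i] * M[i][j].
  def read(memory, w) do
    memory
    # Scale each row by its read weight.
    |> Enum.zip_with(w, fn row, w_i -> Enum.map(row, &(&1 * w_i)) end)
    # Sum the scaled rows column by column.
    |> Enum.zip_with(fn column -> Enum.sum(column) end)
  end
end
```

With one-hot weights such as [0.0, 1.0], the read returns a single row exactly. With [0.25, 0.75], it blends the rows [1.0, 2.0] and [3.0, 4.0] into [2.5, 3.5].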

Parameters

  • controller_out - Controller hidden state [batch, controller_size]
  • memory - Memory matrix [batch, N, M]

Options

  • :memory_size - Number of memory rows N
  • :memory_dim - Dimension of each memory row M
  • :name - Layer name prefix

Returns

Read vector [batch, M]

write_head(controller_out, memory, opts \\ [])

@spec write_head(Axon.t(), Axon.t(), keyword()) :: Axon.t()

Compute write head: full addressing pipeline + erase/add memory update.

Uses the 4-stage addressing pipeline to compute write weights, then updates memory via the erase-then-add mechanism:

M_new = M * (1 - w * e^T) + w * a^T

Parameters

  • controller_out - Controller hidden state [batch, controller_size]
  • memory - Memory matrix [batch, N, M]

Options

  • :memory_size - Number of memory rows N
  • :memory_dim - Dimension of each memory row M
  • :name - Layer name prefix

Returns

Updated memory [batch, N, M]