Edifice.Attention.TMRoPE (Edifice v0.2.0)


TMRoPE: Time-aligned Multimodal RoPE for unified position encoding across modalities.

Extends RoPE to align different modalities (text, image, video) by their temporal position rather than sequence position. This enables coherent cross-modal attention where tokens from the same moment in time share compatible positional encodings.

Key Insight

Standard RoPE assigns positions sequentially (0, 1, 2, ...). For multimodal inputs, this breaks semantic alignment: an image at position 100 might represent the same moment as text at position 50. TMRoPE instead assigns positions based on timestamps:

Text tokens:     "A cat"     -> positions [0.0, 0.0]  (same utterance)
Image patches:   [64 patches] -> positions [1.0, 1.0, ..., 1.0]  (same frame)
Video frames:    [3 frames]   -> positions [2.0, 2.5, 3.0]  (temporal sequence)

Position Assignment

Input: [text_tokens, image_patches, video_frames]
       |
+------v-------------------------------------------+
| Modality Metadata                                |
|   {:text,  [start_idx: 0, end_idx: 5,            |  -> positions = [0.0] * 5
|             time: 0.0]}                          |
|   {:image, [start_idx: 5, end_idx: 69,           |  -> positions = [1.0] * 64
|             time: 1.0]}                          |
|   {:video, [start_idx: 69, end_idx: 261,         |  -> positions = [2.0, 2.5, 3.0, ...]
|             patches_per_frame: 64,               |     per-frame assignment
|             frame_times: [2.0, 2.5, 3.0]]}       |
+--------------------------------------------------+
       |
       v
Position IDs: [0,0,0,0,0, 1,1,...,1, 2,2.5,3,...]
       |
+------v-----------------------------------+
| RoPE with temporal positions             |
+------------------------------------------+
       |
       v
Time-aligned Q, K tensors
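To make the diagram concrete, here is a NumPy sketch of the same assignment logic. The helper name and tuple layout are illustrative only; the library's actual entry point is `TMRoPE.assign_positions/2` with Elixir keyword lists.

```python
import numpy as np

def assign_positions(seq_len, metadata):
    """Assign temporal position IDs from modality metadata.

    Entries (mirroring the keyword lists in the diagram above):
      ("text"/"image", start_idx, end_idx, time)  -> constant time for the span
      ("video", start_idx, end_idx, patches_per_frame, frame_times)
                                                  -> one time per frame's patches
    """
    positions = np.zeros(seq_len, dtype=np.float32)
    for entry in metadata:
        if entry[0] in ("text", "image"):
            _, start, end, t = entry
            positions[start:end] = t
        elif entry[0] == "video":
            _, start, end, ppf, frame_times = entry
            for f, t in enumerate(frame_times):
                lo = start + f * ppf
                positions[lo:lo + ppf] = t
    return positions

metadata = [
    ("text", 0, 5, 0.0),
    ("image", 5, 69, 1.0),
    ("video", 69, 261, 64, [2.0, 2.5, 3.0]),
]
pos = assign_positions(261, metadata)
# pos[:5] are all 0.0, pos[5:69] all 1.0, and the 192 video
# patches fall into three 64-patch groups at 2.0, 2.5, 3.0
```

Note that positions are float-valued: fractional timestamps (e.g. 2.5 for a frame sampled at 2 fps) feed directly into the rotation angles.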

Formula

For each token at temporal position t:

  • θᵢ = base^(-2i/d) (standard RoPE frequencies)
  • Rotation angle = t × θᵢ (using temporal position, not sequence index)
  • Optional temporal scaling: θᵢ' = θᵢ × temporal_scaling
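The formula can be checked with a small NumPy sketch. The pairing convention below (rotating interleaved even/odd feature pairs) is an assumption for illustration; an implementation may instead pair split halves, but the angles t × θᵢ are the same either way.

```python
import numpy as np

def tmrope_rotate(x, t, base=10000.0, temporal_scaling=1.0):
    """Rotate the last dim of x [seq, d] by angles t * theta_i.

    theta_i = base^(-2i/d), optionally scaled; t holds temporal
    positions rather than sequence indices.
    """
    seq, d = x.shape
    i = np.arange(d // 2)
    theta = base ** (-2.0 * i / d) * temporal_scaling  # [d/2]
    angles = t[:, None] * theta[None, :]               # [seq, d/2]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]                    # even/odd feature pairs
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

q = np.random.default_rng(0).normal(size=(4, 8))
t = np.array([0.0, 0.0, 1.0, 2.5])  # temporal positions; ties are allowed
q_rot = tmrope_rotate(q, t)
# t = 0 gives a zero angle, so the first two rows are unchanged,
# and rotations preserve each row's norm
```

Because tokens sharing a timestamp receive identical rotations, their relative attention geometry is that of zero temporal distance, which is exactly the time-alignment property TMRoPE is after.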

Usage

# Build TMRoPE-wrapped attention
model = TMRoPE.build(
  embed_dim: 64,
  modalities: [:text, :image, :video],
  temporal_scaling: 1.0
)

# Assign positions to multimodal sequence
metadata = [
  {:text, [start_idx: 0, end_idx: 5, time: 0.0]},
  {:image, [start_idx: 5, end_idx: 69, time: 1.0]},
  {:video, [start_idx: 69, end_idx: 261, patches_per_frame: 64,
            frame_times: [2.0, 2.5, 3.0]]}
]
position_ids = TMRoPE.assign_positions(261, metadata)

# Apply TMRoPE to Q, K
{q_rot, k_rot} = TMRoPE.apply_tmrope(q, k, position_ids)

Summary

Types

modality_metadata()

Modality metadata for position assignment.

tmrope_opt()

Options for TMRoPE functions.

Functions

apply_tmrope(query, key, position_ids, opts \\ [])

Apply TMRoPE to query and key tensors using temporal position IDs.

assign_positions(seq_len, modality_metadata)

Assign temporal position IDs to a multimodal sequence.

build(opts \\ [])

Build an Axon model that applies TMRoPE to query/key inputs.

frame_times(num_frames, opts \\ [])

Generate sequential frame times for video.

output_size(opts)

Calculate output size (same as input embed_dim).

Get recommended defaults for TMRoPE.

tmrope_freqs(embed_dim, opts \\ [])

Compute TMRoPE frequency table with temporal scaling.

Types

modality_metadata()

@type modality_metadata() ::
  {:text, keyword()}
  | {:image, keyword()}
  | {:video, keyword()}
  | {:audio, keyword()}

Modality metadata for position assignment.

tmrope_opt()

@type tmrope_opt() ::
  {:embed_dim, pos_integer()}
  | {:modalities, [:text | :image | :video | :audio]}
  | {:max_position, pos_integer()}
  | {:temporal_scaling, number()}
  | {:base, number()}
  | {:name, String.t()}

Options for TMRoPE functions.

Functions

apply_tmrope(query, key, position_ids, opts \\ [])

@spec apply_tmrope(Nx.Tensor.t(), Nx.Tensor.t(), Nx.Tensor.t(), keyword()) ::
  {Nx.Tensor.t(), Nx.Tensor.t()}

Apply TMRoPE to query and key tensors using temporal position IDs.

Parameters

  • query - Query tensor [batch, seq_len, embed_dim]
  • key - Key tensor [batch, seq_len, embed_dim]
  • position_ids - Temporal positions [batch, seq_len] or [seq_len]
  • opts - Options:
    • :temporal_scaling - Frequency scaling (default: 1.0)
    • :base - RoPE base frequency (default: 10000.0)

Returns

{rotated_query, rotated_key} with same shapes as input.

Example

positions = TMRoPE.assign_positions(seq_len, metadata)
{q_rot, k_rot} = TMRoPE.apply_tmrope(q, k, positions)

assign_positions(seq_len, modality_metadata)

@spec assign_positions(pos_integer(), [modality_metadata()]) :: Nx.Tensor.t()

Assign temporal position IDs to a multimodal sequence.

Maps token indices to time-aligned positions based on modality metadata. Tokens from the same temporal moment (e.g., same video frame) share positions.

Parameters

  • seq_len - Total sequence length
  • modality_metadata - List of modality specifications:
    • {:text, [start_idx: int, end_idx: int, time: float]} — all tokens get same time
    • {:image, [start_idx: int, end_idx: int, time: float]} — all patches get same time
    • {:video, [start_idx: int, end_idx: int, patches_per_frame: int, frame_times: [float]]} — patches grouped by frame

Returns

Tensor of shape [seq_len] with temporal position IDs.

Example

metadata = [
  {:text, [start_idx: 0, end_idx: 10, time: 0.0]},
  {:image, [start_idx: 10, end_idx: 74, time: 1.0]},
  {:video, [start_idx: 74, end_idx: 266, patches_per_frame: 64, frame_times: [2.0, 3.0, 4.0]]}
]
positions = TMRoPE.assign_positions(266, metadata)

build(opts \\ [])

@spec build([tmrope_opt()]) :: Axon.t()

Build an Axon model that applies TMRoPE to query/key inputs.

Options

  • :embed_dim - Feature dimension (required, must be even)
  • :modalities - List of modality types (default: [:text, :image, :video])
  • :max_position - Maximum temporal position (default: 32768)
  • :temporal_scaling - Scaling factor for temporal frequencies (default: 1.0)
  • :base - RoPE base frequency (default: 10000.0)
  • :name - Layer name prefix (default: "tmrope")

Inputs

The model takes three inputs:

  • "tmrope_query" — Query tensor [batch, seq_len, embed_dim]
  • "tmrope_key" — Key tensor [batch, seq_len, embed_dim]
  • "tmrope_positions" — Position IDs [batch, seq_len] (temporal positions)

Returns

An Axon container with {:query, :key} rotated tensors.

frame_times(num_frames, opts \\ [])

@spec frame_times(
  pos_integer(),
  keyword()
) :: [float()]

Generate sequential frame times for video.

Parameters

  • num_frames - Number of video frames
  • opts - Options:
    • :start_time - Starting time (default: 0.0)
    • :frame_interval - Time between frames (default: 1.0)

Example

TMRoPE.frame_times(3, start_time: 2.0, frame_interval: 0.5)
# => [2.0, 2.5, 3.0]

output_size(opts)

@spec output_size(keyword()) :: pos_integer()

Calculate output size (same as input embed_dim).

tmrope_freqs(embed_dim, opts \\ [])

@spec tmrope_freqs(
  pos_integer(),
  keyword()
) :: Nx.Tensor.t()

Compute TMRoPE frequency table with temporal scaling.

Parameters

  • embed_dim - Feature dimension (must be even)
  • opts - Options:
    • :temporal_scaling - Frequency scaling factor (default: 1.0)
    • :base - RoPE base frequency (default: 10000.0)

Returns

Tensor of shape [embed_dim / 2] with scaled frequencies.

Example

freqs = TMRoPE.tmrope_freqs(64, temporal_scaling: 0.5)
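As a sanity check, the table is just the standard RoPE frequencies base^(-2i/d) multiplied by the scaling factor. A NumPy sketch (the function name mirrors the Elixir one but is otherwise illustrative):

```python
import numpy as np

def tmrope_freqs(embed_dim, temporal_scaling=1.0, base=10000.0):
    # theta_i = base^(-2i/d) * temporal_scaling, for i = 0 .. d/2 - 1
    i = np.arange(embed_dim // 2)
    return base ** (-2.0 * i / embed_dim) * temporal_scaling

freqs = tmrope_freqs(64, temporal_scaling=0.5)
# shape (32,); the fastest frequency drops from 1.0 to 0.5,
# and every slower frequency is halved in proportion
```

Scaling below 1.0 slows all rotations uniformly, which effectively stretches the temporal range the encoding can distinguish; scaling above 1.0 compresses it.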