Edifice.Blocks.CausalMask (Edifice v0.2.0)


Causal and window attention mask utilities.

Provides unified mask creation functions used across attention modules (MultiHead, GQA, InfiniAttention, RingAttention, etc.).

Mask Types

  • Causal: Lower-triangular; each position attends only to itself and earlier positions
  • Window: Causal, with lookback limited to a fixed window of the most recent positions
  • Block diagonal: For chunked/ring attention patterns

All masks are boolean tensors where true = attend, false = mask out.
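To illustrate the true = attend, false = mask out convention, here is a minimal sketch of applying such a mask to raw attention scores before softmax. `MaskApplySketch` and `apply_mask/3` are hypothetical names introduced for this example and are not part of Edifice; the fill value of -1.0e9 is likewise an assumption. Note that Nx has no dedicated boolean type, so a "boolean" mask is in practice a u8 tensor of 1s and 0s.

```elixir
# Assumes Nx is available, e.g. Mix.install([{:nx, "~> 0.7"}]) in a script.
defmodule MaskApplySketch do
  # Hypothetical helper (not an Edifice function).
  # Keeps scores where mask is 1 (attend) and replaces the rest
  # with a large negative value so softmax assigns them ~0 weight.
  def apply_mask(scores, mask, fill \\ -1.0e9) do
    Nx.select(mask, scores, fill)
  end
end

scores = Nx.tensor([[1.0, 2.0], [3.0, 4.0]])
# Hand-written 2x2 causal mask: 1 = attend, 0 = mask out.
mask = Nx.tensor([[1, 0], [1, 1]], type: :u8)

masked = MaskApplySketch.apply_mask(scores, mask)
```

Only the upper-right entry is replaced here, since it is the one place where a position would otherwise attend to a later position.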

Usage

alias Edifice.Blocks.CausalMask

mask = CausalMask.causal(64)
window = CausalMask.window(64, 16)

# Copy to BinaryBackend for use inside Axon.nx closures
mask = CausalMask.to_binary_backend(mask)

Summary

Functions

causal(seq_len)

Create a causal (autoregressive) attention mask.

to_binary_backend(mask)

Copy a mask tensor to Nx.BinaryBackend.

window(seq_len, window_size)

Create a sliding window attention mask.

Functions

causal(seq_len)

@spec causal(non_neg_integer()) :: Nx.Tensor.t()

Create a causal (autoregressive) attention mask.

Returns a boolean tensor of shape [seq_len, seq_len] where position i can attend to positions 0..i.

Examples

iex> mask = Edifice.Blocks.CausalMask.causal(4)
iex> Nx.shape(mask)
{4, 4}
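As a rough illustration of the lower-triangular pattern described above, a causal mask can be built by comparing row and column indices. This is a sketch under stated assumptions, not necessarily how Edifice implements `causal/1`; `CausalSketch` is a hypothetical module name.

```elixir
# Assumes Nx is available, e.g. Mix.install([{:nx, "~> 0.7"}]) in a script.
defmodule CausalSketch do
  # Hypothetical re-implementation for illustration only.
  # Entry [i, j] is 1 when query position i may attend to key position j,
  # that is, when j <= i.
  def causal(seq_len) do
    rows = Nx.iota({seq_len, seq_len}, axis: 0)
    cols = Nx.iota({seq_len, seq_len}, axis: 1)
    Nx.greater_equal(rows, cols)
  end
end

mask = CausalSketch.causal(3)
# Rows of the resulting mask:
#   [1, 0, 0]
#   [1, 1, 0]
#   [1, 1, 1]
```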

to_binary_backend(mask)

@spec to_binary_backend(Nx.Tensor.t()) :: Nx.Tensor.t()

Copy a mask tensor to Nx.BinaryBackend.

Required when capturing masks in Axon.nx closures to avoid EXLA/Defn.Expr backend mismatch during JIT compilation.
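A backend copy of this kind can be expressed with `Nx.backend_copy/2`. The sketch below shows the idea in isolation; whether `to_binary_backend/1` is implemented exactly this way is an assumption, and `BackendSketch` is a hypothetical module name.

```elixir
# Assumes Nx is available, e.g. Mix.install([{:nx, "~> 0.7"}]) in a script.
defmodule BackendSketch do
  # Hypothetical helper for illustration only.
  # Returns a copy of the tensor whose data lives on Nx.BinaryBackend,
  # regardless of which backend the input tensor was allocated on.
  def to_binary_backend(tensor) do
    Nx.backend_copy(tensor, Nx.BinaryBackend)
  end
end

mask = Nx.tensor([[1, 0], [1, 1]], type: :u8)
copied = BackendSketch.to_binary_backend(mask)
```

The copy holds the same values as the original, so it can be captured in a closure without carrying a reference to a device-specific backend.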

window(seq_len, window_size)

@spec window(non_neg_integer(), non_neg_integer()) :: Nx.Tensor.t()

Create a sliding window attention mask.

Each position attends to at most window_size positions: itself and the positions immediately preceding it. This combines the causal constraint with the window constraint.

Examples

iex> mask = Edifice.Blocks.CausalMask.window(8, 3)
iex> Nx.shape(mask)
{8, 8}
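One way to picture the combined constraint is as two index comparisons: the key position must not be later than the query position, and the distance between them must be smaller than the window size. The following is a sketch under that reading of the description, not necessarily how Edifice implements `window/2`; `WindowSketch` is a hypothetical module name.

```elixir
# Assumes Nx is available, e.g. Mix.install([{:nx, "~> 0.7"}]) in a script.
defmodule WindowSketch do
  # Hypothetical re-implementation for illustration only.
  # Entry [i, j] is 1 when j <= i (causal) and i - j < window_size (window),
  # so each row has at most window_size ones, counting the diagonal.
  def window(seq_len, window_size) do
    rows = Nx.iota({seq_len, seq_len}, axis: 0)
    cols = Nx.iota({seq_len, seq_len}, axis: 1)

    causal = Nx.greater_equal(rows, cols)
    in_window = Nx.less(Nx.subtract(rows, cols), window_size)

    Nx.logical_and(causal, in_window)
  end
end

mask = WindowSketch.window(4, 2)
# Rows of the resulting mask:
#   [1, 0, 0, 0]
#   [1, 1, 0, 0]
#   [0, 1, 1, 0]
#   [0, 0, 1, 1]
```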