Causal and window attention mask utilities.
Provides unified mask creation functions used across attention modules (MultiHead, GQA, InfiniAttention, RingAttention, etc.).
Mask Types
- Causal: Lower-triangular — each position attends only to itself and earlier positions
- Window: Causal + limited lookback window
- Block diagonal: For chunked/ring attention patterns
All masks are boolean tensors where true = attend, false = mask out.
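Each of the three mask types reduces to a simple predicate over the (row, column) index pair. A minimal pure-Elixir sketch of those rules, independent of the actual Nx implementation (the `MaskRules` module and `attend?/4` helper are hypothetical names for illustration only):

```elixir
# Hypothetical helper illustrating each mask type's attend/mask-out rule.
# i = query position (row), j = key position (column).
defmodule MaskRules do
  # Causal: attend to self and all earlier positions.
  def attend?(:causal, i, j, _opts), do: j <= i

  # Window: causal, but look back at most window_size positions
  # (counting the position itself).
  def attend?(:window, i, j, window_size), do: j <= i and j > i - window_size

  # Block diagonal: attend only within the same fixed-size block.
  def attend?(:block_diagonal, i, j, block_size),
    do: div(i, block_size) == div(j, block_size)
end

# Row 3 of a causal mask over 5 positions:
# for j <- 0..4, do: MaskRules.attend?(:causal, 3, j, nil)
# => [true, true, true, true, false]
```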
Usage
mask = CausalMask.causal(64)
window = CausalMask.window(64, 16)
# Copy to BinaryBackend for use inside Axon.nx closures
mask = CausalMask.to_binary_backend(mask)
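In a typical attention implementation, a boolean mask like this is applied before the softmax by replacing masked-out score entries with a large negative value, so their attention weight goes to zero. A hedged sketch on plain lists (the `MaskApply` module, `apply_mask/2` name, and `-1.0e9` sentinel are illustrative assumptions, not the library's API; real tensor code would use `Nx.select/3`):

```elixir
# Illustrative only: replace masked-out scores with a large negative
# number so softmax drives their attention weight toward zero.
# The tensor equivalent would be roughly:
#   Nx.select(mask, scores, Nx.Constants.neg_infinity())
defmodule MaskApply do
  @neg_big -1.0e9

  def apply_mask(scores, mask) do
    Enum.zip_with(scores, mask, fn score_row, mask_row ->
      Enum.zip_with(score_row, mask_row, fn s, m ->
        if m, do: s, else: @neg_big
      end)
    end)
  end
end
```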
Summary
Functions
Create a causal (autoregressive) attention mask.
Copy a mask tensor to Nx.BinaryBackend.
Create a sliding window attention mask.
Functions
@spec causal(non_neg_integer()) :: Nx.Tensor.t()
Create a causal (autoregressive) attention mask.
Returns a boolean tensor of shape {seq_len, seq_len} where position i
can attend to positions 0..i.
Examples
iex> mask = Edifice.Blocks.CausalMask.causal(4)
iex> Nx.shape(mask)
{4, 4}
@spec to_binary_backend(Nx.Tensor.t()) :: Nx.Tensor.t()
Copy a mask tensor to Nx.BinaryBackend.
Required when capturing masks in Axon.nx closures to avoid
EXLA/Defn.Expr backend mismatch during JIT compilation.
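The closure-capture pattern might look like the following sketch (hedged: the layer wiring and input name are illustrative, not taken from the library's code; `Nx.backend_copy/2`, `Axon.input/2`, `Axon.nx/2`, and `Nx.select/3` are real Nx/Axon functions):

```elixir
# Sketch: capture a precomputed mask inside an Axon.nx closure.
# Without the binary-backend copy, a mask created on EXLA can clash
# with the Defn.Expr backend used during JIT tracing.
mask =
  64
  |> Edifice.Blocks.CausalMask.causal()
  |> Edifice.Blocks.CausalMask.to_binary_backend()

scores = Axon.input("scores", shape: {nil, 64, 64})

masked =
  Axon.nx(scores, fn s ->
    # mask broadcasts over the batch dimension of s
    Nx.select(mask, s, Nx.Constants.neg_infinity())
  end)
```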
@spec window(non_neg_integer(), non_neg_integer()) :: Nx.Tensor.t()
Create a sliding window attention mask.
Each position attends to at most window_size positions: itself and up to
window_size - 1 immediately preceding positions. Combines the causal
constraint with the window constraint.
Examples
iex> mask = Edifice.Blocks.CausalMask.window(8, 3)
iex> Nx.shape(mask)
{8, 8}
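To make the window semantics concrete: row i of window(n, w) is true exactly for columns max(0, i - w + 1)..i. A small pure-Elixir sketch of that rule (`window_row` is a hypothetical helper, not part of the module):

```elixir
# Hypothetical sketch of one row of a window mask:
# position i attends to itself and up to w - 1 preceding positions.
window_row = fn i, w, n ->
  for j <- 0..(n - 1), do: j <= i and j > i - w
end

# Row 5 of window(8, 3): only columns 3..5 are allowed.
window_row.(5, 3, 8)
# => [false, false, false, true, true, true, false, false]
```

Near the start of the sequence the window is clipped by the causal constraint, so early rows attend to fewer than window_size positions.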