EMLX.Fast (emlx v0.3.0)


Single-kernel Metal shaders from mlx::fast, exposed as deftransform functions backed by Nx.runtime_call.

Every function is defn-safe: call inside defn, Nx.Defn.jit, or from Axon.rewrite_nodes/2 rewrite callbacks without restriction.
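For instance, a hand-written FFN block can combine them directly inside a defn. The sketch below is illustrative only: the module name, weight shapes and eps value are assumptions, not part of the API.

defmodule MyModel.FusedFFN do
  import Nx.Defn

  # Sketch: fused RMSNorm followed by fused SwiGLU, with plain Nx.dot projections.
  # x: {batch, seq, hidden}; norm_weight: {hidden}; w_gate/w_up: {hidden, ffn}.
  defn ffn_block(x, norm_weight, w_gate, w_up) do
    h = EMLX.Fast.rms_norm(x, norm_weight, 1.0e-6)
    EMLX.Fast.swiglu(Nx.dot(h, w_gate), Nx.dot(h, w_up))
  end
end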


Axon graph rewrite example

Axon.rewrite_nodes(model, fn
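  # Swap every :rms_norm node for the fused Metal kernel, reusing its eps option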
  %Axon.Node{op: :rms_norm, opts: [eps: eps]} ->
    fn [x, weight], _output -> EMLX.Fast.rms_norm(x, weight, eps) end
  _ -> :skip
end)

Summary

Functions

layer_norm(x, weight, eps)
Fused layer normalisation without bias (mlx::fast::layer_norm, weight-only variant).

layer_norm(x, weight, bias, eps)
Fused layer normalisation (mlx::fast::layer_norm).

rms_norm(x, weight, eps)
Fused RMS normalisation (mlx::fast::rms_norm).

rope(a, dims, traditional, base, scale, offset)
Fused rotary position embedding (mlx::fast::rope).

rope_with_freqs(a, position_ids, dims, traditional, scale, freqs)
Fused RoPE with precomputed inverse-frequency vector (mlx::fast::rope, freqs overload).

rope_with_positions(a, position_ids, dims, traditional, base, scale)
Fused RoPE accepting a position_ids tensor (mlx::fast::rope, array-offset overload).

scaled_dot_product_attention(q, k, v, scale)
Flash-attention SDPA, no mask (mlx::fast::scaled_dot_product_attention).

scaled_dot_product_attention(q, k, v, scale, mask)
Flash-attention SDPA with an additive or boolean mask.

scaled_dot_product_attention_causal(q, k, v, scale)
Flash-attention SDPA with a built-in causal mask (mlx::fast::scaled_dot_product_attention, mask_mode="causal").

scaled_dot_product_attention_causal_key_masked(q, k, v, scale, key_mask)
Causal SDPA with the key_mask check delegated to the C++ NIF.

swiglu(gate, up)
Fused SwiGLU activation: silu(gate) * up where silu(x) = x * sigmoid(x).

Functions

layer_norm(x, weight, eps)

Fused layer normalisation without bias (mlx::fast::layer_norm, weight-only variant).

  • x — input tensor; normalised over the last axis.
  • weight — {hidden} scale vector (gamma).
  • eps — numerical stability constant (e.g. 1.0e-5).

Output shape and type match x.
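A minimal usage sketch; the shapes, the eps value and the Nx.Defn.jit wrapper are illustrative, not requirements of the API:

# Hypothetical shapes: x is {2, 4, 8}, weight is {8}.
x = Nx.iota({2, 4, 8}, type: :f32)
weight = Nx.broadcast(1.0, {8})

y = Nx.Defn.jit(&EMLX.Fast.layer_norm(&1, &2, 1.0e-5)).(x, weight)
# y keeps the {2, 4, 8} shape and :f32 type of x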

layer_norm(x, weight, bias, eps)

Fused layer normalisation (mlx::fast::layer_norm).

  • x — input tensor; normalised over the last axis.
  • weight — {hidden} scale vector (gamma).
  • bias — {hidden} bias vector (beta).
  • eps — numerical stability constant (e.g. 1.0e-5).

Output shape and type match x.
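For reference, an unfused Nx expression of the same computation (with x, weight, bias and eps as above; population variance over the last axis is an assumption taken from the standard LayerNorm definition, not a statement about the kernel's internals):

# (x - mean) / sqrt(var + eps) * weight + bias
mean = Nx.mean(x, axes: [-1], keep_axes: true)
var = Nx.variance(x, axes: [-1], keep_axes: true)

x
|> Nx.subtract(mean)
|> Nx.divide(Nx.sqrt(Nx.add(var, eps)))
|> Nx.multiply(weight)
|> Nx.add(bias)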

rms_norm(x, weight, eps)

Fused RMS normalisation (mlx::fast::rms_norm).

  • x — input tensor; normalised over the last axis.
  • weight — {hidden} scale vector (same size as last axis of x).
  • eps — numerical stability constant (e.g. 1.0e-6).

Output shape and type match x.
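The unfused Nx equivalent (with x, weight and eps as above; accumulation precision inside the fused kernel may differ):

# x * rsqrt(mean(x^2) + eps) * weight
ms = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)

x
|> Nx.multiply(Nx.rsqrt(Nx.add(ms, eps)))
|> Nx.multiply(weight)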

rope(a, dims, traditional, base, scale, offset)

Fused rotary position embedding (mlx::fast::rope).

  • a — input {B, ..., T, D}; ... dims are passed through.
  • dims — number of feature dims to rotate (≤ last-axis size, must be even).
  • traditional — false for split-half (Qwen3); true for interleaved.
  • base — angular frequency base (e.g. 10_000 or 1_000_000).
  • scale — position scale (1.0 unless using NTK-aware scaling).
  • offset — integer position offset (tokens already in the KV cache).

traditional must match the model checkpoint's convention. For Qwen3 (split-half): traditional: false.

Output shape and type match a.
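A decode-step sketch, assuming a {B, heads, 1, head_dim} layout and that offset is the number of tokens already in the KV cache; all concrete values are illustrative:

# One new query token at position 17, Qwen3-style base, split-half rotation.
cache_len = 17
q = Nx.iota({1, 8, 1, 64}, type: :f32)

rotated =
  Nx.Defn.jit(fn q ->
    EMLX.Fast.rope(q, 64, false, 1_000_000, 1.0, cache_len)
  end).(q)
# rotated keeps the {1, 8, 1, 64} shape and type of q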

rope_with_freqs(a, position_ids, dims, traditional, scale, freqs)

Fused RoPE with precomputed inverse-frequency vector (mlx::fast::rope, freqs overload).

Use this variant when the model's RoPE scaling strategy produces a fixed {dims/2} inv-frequency tensor that can be baked at graph-rewrite time (e.g. :llama3 smooth-interpolation). Strategies that are seq-len conditional or require cos/sin post-multiply (:linear, :dynamic, :longrope) should use rope_with_positions/6 instead.

  • a — input {B, T, ..., D} (Bumblebee convention: heads NOT yet transposed)
  • position_ids — {B, T} integer tensor. For decode (T = 1) the fast path uses position_ids[b, 0] as the per-batch offset into freqs (same contract as mlx::fast::rope with a scalar offset per batch). For prefill (T > 1) a per-token Nx path runs so arbitrary positions (e.g. left-padded [0, …, 0, 1, 2, …]) are correct; the NIF's offset-only entry point cannot represent that.
  • dims — number of feature dims to rotate.
  • traditional — false for split-half (Bumblebee / Qwen3); true for interleaved.
  • scale — position scale (1.0 for most strategies with precomputed freqs).
  • freqs — {dims/2} tensor of precomputed inverse frequencies.

Output shape and type match a.
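A prefill sketch, assuming a rewrite pass has already produced the {dims/2} freqs tensor; precomputed_freqs below stands in for whatever the scaling strategy yields and is not provided by this module, and the shapes are illustrative:

# Bumblebee layout {B, T, heads, D}: 5 prefill tokens, 8 heads, head_dim 64.
a = Nx.iota({1, 5, 8, 64}, type: :f32)
position_ids = Nx.iota({1, 5})    # 0..4, sequential
freqs = precomputed_freqs         # {32} tensor baked at graph-rewrite time

EMLX.Fast.rope_with_freqs(a, position_ids, 64, false, 1.0, freqs)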

rope_with_positions(a, position_ids, dims, traditional, base, scale)

Fused RoPE accepting a position_ids tensor (mlx::fast::rope, array-offset overload).

Use this variant when the calling convention provides position_ids as a tensor (e.g. from Bumblebee's rotary embedding layer) rather than a scalar integer offset.

  • a — input {B, T, ..., D} (Bumblebee convention: heads NOT yet transposed)
  • position_ids — {B, T} integer tensor; each row holds the token positions for one batch example. Positions must be sequential within each row (standard causal LM). The starting offset for batch item b is taken as position_ids[b, 0]; subsequent positions are inferred by MLX as offset + 0, offset + 1, ....
  • dims — number of feature dims to rotate.
  • traditional — false for split-half (Bumblebee / Qwen3); true for interleaved.
  • base — angular frequency base (e.g. 10_000).
  • scale — position scale (1.0 unless using NTK-aware scaling).

Output shape and type match a.

Sequential positions only (fast T=1 path)

For decode with T = 1 and base below about 1.0e5, the fast_rope_ids NIF is used; it assumes sequential positions from position_ids[b, 0]. For larger base (e.g. Qwen3 rope_theta 1M) or prefill (T > 1), the Nx per-token path is used, matching Bumblebee for arbitrary per-token position_ids.
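For example, sequential position_ids for a batched decode step can be built from per-example KV-cache lengths (cache_lengths and the shapes below are illustrative):

# Decode step, Bumblebee layout {B, T, heads, D} with T = 1.
a = Nx.iota({2, 1, 8, 64}, type: :f32)
cache_lengths = Nx.tensor([5, 9])               # tokens already cached per example
offsets = Nx.reshape(cache_lengths, {:auto, 1}) # {2, 1}
position_ids = Nx.add(offsets, Nx.iota({1, 1})) # {2, 1}, sequential within each row

EMLX.Fast.rope_with_positions(a, position_ids, 64, false, 10_000, 1.0)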

scaled_dot_product_attention(q, k, v, scale)

Flash-attention SDPA, no mask (mlx::fast::scaled_dot_product_attention).

GQA-native: k/v may have fewer heads than q — no pre-tiling required.

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — scalar (typically 1 / sqrt(D))

Output: {B, N_q, T_q, D} — same dtype as q. Softmax accumulates in float32 internally regardless of input dtype.
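A GQA decode-step sketch (8 query heads sharing 2 KV heads; shapes and the jit wrapper are illustrative):

# q: {1, 8, 1, 64}; k/v: {1, 2, 32, 64}; no tiling of k/v to 8 heads needed.
q = Nx.iota({1, 8, 1, 64}, type: :f16)
k = Nx.iota({1, 2, 32, 64}, type: :f16)
v = Nx.iota({1, 2, 32, 64}, type: :f16)
scale = 1.0 / :math.sqrt(64)

out =
  Nx.Defn.jit(fn q, k, v ->
    EMLX.Fast.scaled_dot_product_attention(q, k, v, scale)
  end).(q, k, v)
# out: {1, 8, 1, 64}, :f16 like q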

scaled_dot_product_attention(q, k, v, scale, mask)

Flash-attention SDPA with an additive or boolean mask.

mask must be broadcast-compatible with {B, N_q, T_q, T_kv}. Boolean false entries are masked out (-∞); float entries are added to the pre-softmax scores.

For causal masking in decode (single query token), prefer the no-mask arity since T_q=1 is always trivially causal.
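For example, a {B, T_kv} padding mask can be expanded into a broadcast-compatible boolean mask (q, k, v and scale as in the previous arity; the 1-for-token / 0-for-pad convention, and treating the 0/1 predicate tensor as the boolean mask, are assumptions):

# {B, T_kv} -> {B, 1, 1, T_kv}; false entries are masked out.
key_padding = Nx.tensor([[0, 0, 1, 1, 1]])

bool_mask =
  key_padding
  |> Nx.equal(1)
  |> Nx.new_axis(1)
  |> Nx.new_axis(2)

EMLX.Fast.scaled_dot_product_attention(q, k, v, scale, bool_mask)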

scaled_dot_product_attention_causal(q, k, v, scale)

Flash-attention SDPA with a built-in causal mask (mlx::fast::scaled_dot_product_attention, mask_mode="causal").

MLX constructs the upper-triangular causal mask internally without materialising it, making this equivalent to scaled_dot_product_attention/5 with a causal boolean mask but cheaper: no mask tensor allocation, and the mask is fused into the Metal kernel.

GQA-native: k/v may have fewer heads than q — no pre-tiling required.

Input/output layout matches scaled_dot_product_attention/4:

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — pre-computed scalar (typically 1 / sqrt(D))
  • Output — {B, N_q, T_q, D}, same dtype as q
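A prefill sketch, where T_q equals T_kv and the causal mask actually restricts attention (shapes are illustrative):

# 32 query tokens attend causally to 32 keys; GQA shapes as above.
q = Nx.iota({1, 8, 32, 64}, type: :f32)
k = Nx.iota({1, 2, 32, 64}, type: :f32)
v = Nx.iota({1, 2, 32, 64}, type: :f32)

EMLX.Fast.scaled_dot_product_attention_causal(q, k, v, 1.0 / :math.sqrt(64))
# {1, 8, 32, 64}; each query position only sees keys at positions <= its own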

scaled_dot_product_attention_causal_key_masked(q, k, v, scale, key_mask)

Causal SDPA with the key_mask check delegated to the C++ NIF.

At runtime the NIF evaluates all(key_mask == 1):

  • true (no padding, e.g. single-sequence decode) → pure causal SDPA, no mask tensor allocated.
  • false (padded batch or multi-sequence) → builds a combined causal + key_mask additive mask and calls the masked SDPA kernel.

This avoids the Nx.cond double-evaluation problem: the NIF forces eval of only the small {B, T_kv} key_mask subgraph, then branches in C++.

Input/output layout matches scaled_dot_product_attention_causal/4:

  • q — {B, N_q, T_q, D}
  • k — {B, N_kv, T_kv, D}
  • v — {B, N_kv, T_kv, D}
  • scale — pre-computed scalar
  • key_mask — {B, T_kv} boolean/int tensor (1 = attend, 0 = masked)
  • Output — {B, N_q, T_q, D}, same dtype as q
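A typical decode-loop sketch; attention_mask here stands for the tokenizer's {B, T_kv} mask (1 for real tokens, 0 for padding) and is not produced by this module, and q, k, v and scale are as above:

# Derive key_mask from the padding mask and let the NIF pick the branch.
key_mask = Nx.equal(attention_mask, 1)

EMLX.Fast.scaled_dot_product_attention_causal_key_masked(q, k, v, scale, key_mask)
# All-ones masks take the pure-causal fast path; padded batches get the combined mask.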

swiglu(gate, up)

Fused SwiGLU activation: silu(gate) * up where silu(x) = x * sigmoid(x).

Eliminates the two-op silu(gate_proj) * up_proj pattern that appears in Qwen3's FFN layers (28× per decode step).

  • gate — gate-projection output; silu is applied element-wise.
  • up — up-projection output; same shape as gate.

Output has the same shape and dtype as gate.
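For comparison, the unfused pattern and the fused call (gate and up are the projection outputs; illustrative only):

# Unfused two-op pattern: silu(gate) * up
unfused = gate |> Nx.sigmoid() |> Nx.multiply(gate) |> Nx.multiply(up)

# Fused single-kernel equivalent
fused = EMLX.Fast.swiglu(gate, up)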