Single-kernel Metal shaders from mlx::fast, exposed as deftransform
functions backed by Nx.runtime_call.
Every function is defn-safe: it can be called inside defn, under Nx.Defn.jit, or from
Axon.rewrite_nodes/2 rewrite callbacks without restriction.
Functions
- `rms_norm/3` — fused RMS normalisation
- `layer_norm/4` — fused layer normalisation (with bias)
- `layer_norm/3` — fused layer normalisation (weight-only, no bias)
- `rope/6` — fused RoPE with scalar integer offset
- `rope_with_positions/6` — fused RoPE accepting a `position_ids` tensor
- `rope_with_freqs/6` — fused RoPE with precomputed inv-frequency tensor (for `:llama3` scaling)
- `scaled_dot_product_attention/4` — flash-attention SDPA (no mask)
- `scaled_dot_product_attention/5` — flash-attention SDPA (additive/bool mask)
- `scaled_dot_product_attention_causal/4` — flash-attention SDPA with built-in causal mask
- `scaled_dot_product_attention_causal_key_masked/5` — causal SDPA; checks `key_mask` at C++ level, fast-paths to pure causal when all-ones
- `swiglu/2` — fused SwiGLU: `silu(gate) * up`
Axon graph rewrite example
```elixir
Axon.rewrite_nodes(model, fn
  %Axon.Node{op: :rms_norm, opts: [eps: eps]} ->
    fn [x, weight], _output -> EMLX.Fast.rms_norm(x, weight, eps) end

  _ ->
    :skip
end)
```
Summary
Functions
- `layer_norm/3` — Fused layer normalisation without bias (`mlx::fast::layer_norm`, weight-only variant).
- `layer_norm/4` — Fused layer normalisation (`mlx::fast::layer_norm`).
- `rms_norm/3` — Fused RMS normalisation (`mlx::fast::rms_norm`).
- `rope/6` — Fused rotary position embedding (`mlx::fast::rope`).
- `rope_with_freqs/6` — Fused RoPE with precomputed inverse-frequency vector (`mlx::fast::rope`, freqs overload).
- `rope_with_positions/6` — Fused RoPE accepting a `position_ids` tensor (`mlx::fast::rope`, array-offset overload).
- `scaled_dot_product_attention/4` — Flash-attention SDPA, no mask (`mlx::fast::scaled_dot_product_attention`).
- `scaled_dot_product_attention/5` — Flash-attention SDPA with an additive or boolean mask.
- `scaled_dot_product_attention_causal/4` — Flash-attention SDPA with a built-in causal mask (`mlx::fast::scaled_dot_product_attention`, `mask_mode="causal"`).
- `scaled_dot_product_attention_causal_key_masked/5` — Causal SDPA with the `key_mask` check delegated to the C++ NIF.
- `swiglu/2` — Fused SwiGLU activation: `silu(gate) * up` where `silu(x) = x * sigmoid(x)`.
Functions
Fused layer normalisation without bias (mlx::fast::layer_norm, weight-only variant).
- `x` — input tensor; normalised over the last axis.
- `weight` — `{hidden}` scale vector (gamma).
- `eps` — numerical stability constant (e.g. `1.0e-5`).
Output shape and type match x.
Fused layer normalisation (mlx::fast::layer_norm).
- `x` — input tensor; normalised over the last axis.
- `weight` — `{hidden}` scale vector (gamma).
- `bias` — `{hidden}` bias vector (beta).
- `eps` — numerical stability constant (e.g. `1.0e-5`).
Output shape and type match x.
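A minimal sketch of both layer-norm arities. The shapes are illustrative, and the examples assume an EMLX (Metal) default backend is configured:

```elixir
# Illustrative shapes: batch of 2, sequence of 4, hidden size 8.
x = Nx.iota({2, 4, 8}, type: :f32)
weight = Nx.broadcast(1.0, {8})
bias = Nx.broadcast(0.0, {8})

# With bias (layer_norm/4) and weight-only (layer_norm/3);
# both return a {2, 4, 8} tensor with the same type as x.
y = EMLX.Fast.layer_norm(x, weight, bias, 1.0e-5)
y_no_bias = EMLX.Fast.layer_norm(x, weight, 1.0e-5)
```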
Fused RMS normalisation (mlx::fast::rms_norm).
- `x` — input tensor; normalised over the last axis.
- `weight` — `{hidden}` scale vector (same size as last axis of `x`).
- `eps` — numerical stability constant (e.g. `1.0e-6`).
Output shape and type match x.
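Because the function is defn-safe, it can be called directly inside a defn body. A sketch with illustrative shapes, assuming an EMLX backend:

```elixir
defmodule MyNorm do
  import Nx.Defn

  # rms_norm/3 runs as a single fused Metal kernel via Nx.runtime_call.
  defn forward(x, weight) do
    EMLX.Fast.rms_norm(x, weight, 1.0e-6)
  end
end

x = Nx.iota({1, 4, 8}, type: :f32)
weight = Nx.broadcast(1.0, {8})
MyNorm.forward(x, weight)  # => {1, 4, 8} tensor, same type as x
```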
Fused rotary position embedding (mlx::fast::rope).
- `a` — input `{B, ..., T, D}`; `...` dims are passed through.
- `dims` — number of feature dims to rotate (≤ last-axis size, must be even).
- `traditional` — `false` for split-half (Qwen3); `true` for interleaved.
- `base` — angular frequency base (e.g. `10_000` or `1_000_000`).
- `scale` — position scale (`1.0` unless using NTK-aware scaling).
- `offset` — integer position offset (tokens already in the KV cache).
traditional must match the model checkpoint's convention.
For Qwen3 (split-half): traditional: false.
Output shape and type match a.
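A decode-step sketch with illustrative shapes, assuming the argument order follows the parameter list above and an EMLX backend is active:

```elixir
# Single decode token with 8 tokens already in the KV cache;
# rotate all 64 feature dims, split-half convention, base 10_000.
a = Nx.iota({1, 1, 64}, type: :f32)  # {B, T, D}
rotated = EMLX.Fast.rope(a, 64, false, 10_000, 1.0, 8)
# rotated: {1, 1, 64}, same type as a
```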
Fused RoPE with precomputed inverse-frequency vector (mlx::fast::rope, freqs overload).
Use this variant when the model's RoPE scaling strategy produces a fixed
{dims/2} inv-frequency tensor that can be baked at graph-rewrite time
(e.g. :llama3 smooth-interpolation). Strategies that are seq-len conditional
or require cos/sin post-multiply (:linear, :dynamic, :longrope) should
use rope_with_positions/6 instead.
- `a` — input `{B, T, ..., D}` (Bumblebee convention: heads NOT yet transposed).
- `position_ids` — `{B, T}` integer tensor. For decode (T = 1) the fast path uses `position_ids[b, 0]` as the per-batch offset into `freqs` (same contract as `mlx::fast::rope` with a scalar offset per batch). For prefill (T > 1) a per-token Nx path runs so arbitrary positions (e.g. left-padded `[0, …, 0, 1, 2, …]`) are correct; the NIF's offset-only entry point cannot represent that.
- `dims` — number of feature dims to rotate.
- `traditional` — `false` for split-half (Bumblebee / Qwen3); `true` for interleaved.
- `scale` — position scale (`1.0` for most strategies with precomputed freqs).
- `freqs` — `{dims/2}` tensor of precomputed inverse frequencies.
Output shape and type match a.
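A sketch of baking an inverse-frequency tensor and passing it in. The plain `base^(-2i/dims)` frequencies shown here stand in for whatever a `:llama3` smooth-interpolation strategy would produce; shapes and argument order (per the parameter list above) are assumptions:

```elixir
dims = 64

# Plain inverse frequencies base^(-2i/dims); a :llama3 strategy would
# rescale these once at graph-rewrite time before baking them in.
inv_freqs =
  Nx.pow(10_000.0, Nx.divide(Nx.multiply(Nx.iota({div(dims, 2)}), -2.0), dims))

a = Nx.iota({1, 3, dims}, type: :f32)
position_ids = Nx.tensor([[0, 1, 2]])
EMLX.Fast.rope_with_freqs(a, position_ids, dims, false, 1.0, inv_freqs)
```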
Fused RoPE accepting a position_ids tensor (mlx::fast::rope, array-offset overload).
Use this variant when the calling convention provides position_ids as a tensor
(e.g. from Bumblebee's rotary embedding layer) rather than a scalar integer offset.
- `a` — input `{B, T, ..., D}` (Bumblebee convention: heads NOT yet transposed).
- `position_ids` — `{B, T}` integer tensor; each row holds the token positions for one batch example. **Positions must be sequential within each row** (standard causal LM). The starting offset for batch item `b` is taken as `position_ids[b, 0]`; subsequent positions are inferred by MLX as `offset + 0, offset + 1, ...`.
- `dims` — number of feature dims to rotate.
- `traditional` — `false` for split-half (Bumblebee / Qwen3); `true` for interleaved.
- `base` — angular frequency base (e.g. `10_000`).
- `scale` — position scale (`1.0` unless using NTK-aware scaling).
Output shape and type match a.
Sequential positions only (fast T=1 path)
For decode with T = 1 and base below about 1.0e5, the fast_rope_ids NIF
is used; it assumes sequential positions from position_ids[b, 0]. For larger
base (e.g. Qwen3 rope_theta 1M) or prefill (T > 1), the Nx per-token
path is used, matching Bumblebee for arbitrary per-token position_ids.
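A sketch of both paths with illustrative shapes, assuming the argument order follows the parameter list above:

```elixir
# Prefill with T = 3: the per-token Nx path runs.
a = Nx.iota({1, 3, 64}, type: :f32)
position_ids = Nx.tensor([[0, 1, 2]])
EMLX.Fast.rope_with_positions(a, position_ids, 64, false, 10_000, 1.0)

# Decode with T = 1: the fast NIF path runs, taking the offset
# from position_ids[b, 0] (here, position 3).
a1 = Nx.iota({1, 1, 64}, type: :f32)
EMLX.Fast.rope_with_positions(a1, Nx.tensor([[3]]), 64, false, 10_000, 1.0)
```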
Flash-attention SDPA, no mask (mlx::fast::scaled_dot_product_attention).
GQA-native: k/v may have fewer heads than q — no pre-tiling required.
- `q` — `{B, N_q, T_q, D}`
- `k` — `{B, N_kv, T_kv, D}`
- `v` — `{B, N_kv, T_kv, D}`
- `scale` — scalar (typically `1 / sqrt(D)`)
Output: {B, N_q, T_q, D} — same dtype as q.
Softmax accumulates in float32 internally regardless of input dtype.
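A GQA decode sketch with illustrative shapes, assuming an EMLX backend:

```elixir
# GQA: 8 query heads share 2 KV heads; k/v are passed un-tiled.
d = 64
q = Nx.iota({1, 8, 1, d}, type: :f32)    # single decode token
k = Nx.iota({1, 2, 16, d}, type: :f32)   # 16 cached keys
v = Nx.iota({1, 2, 16, d}, type: :f32)

out = EMLX.Fast.scaled_dot_product_attention(q, k, v, 1.0 / :math.sqrt(d))
# out: {1, 8, 1, 64}, same dtype as q
```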
Flash-attention SDPA with an additive or boolean mask.
mask must be broadcast-compatible with {B, N_q, T_q, T_kv}.
Boolean false entries are masked out (-∞); float entries are added to
the pre-softmax scores.
For causal masking in decode (single query token), prefer the no-mask arity
since T_q=1 is always trivially causal.
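A sketch of building an additive padding mask from a `{B, T_kv}` key mask and passing it as the fifth argument (the mask position in the argument list is an assumption); shapes are illustrative:

```elixir
# Additive mask: 0.0 where attendable, a large negative where padded.
key_mask = Nx.tensor([[1, 1, 1, 0]])  # {B, T_kv}

additive =
  key_mask
  |> Nx.equal(0)
  |> Nx.select(-1.0e9, 0.0)
  |> Nx.reshape({1, 1, 1, 4})  # broadcastable to {B, N_q, T_q, T_kv}

q = Nx.iota({1, 4, 1, 64}, type: :f32)
k = Nx.iota({1, 4, 4, 64}, type: :f32)
v = Nx.iota({1, 4, 4, 64}, type: :f32)
EMLX.Fast.scaled_dot_product_attention(q, k, v, 0.125, additive)
```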
Flash-attention SDPA with a built-in causal mask (mlx::fast::scaled_dot_product_attention,
mask_mode="causal").
MLX constructs the upper-triangular causal mask internally without materialising it,
making this equivalent to scaled_dot_product_attention/5 with a causal boolean mask
but cheaper: no mask tensor allocation, and the mask is fused into the Metal kernel.
GQA-native: k/v may have fewer heads than q — no pre-tiling required.
Input/output layout matches scaled_dot_product_attention/4:
- `q` — `{B, N_q, T_q, D}`
- `k` — `{B, N_kv, T_kv, D}`
- `v` — `{B, N_kv, T_kv, D}`
- `scale` — pre-computed scalar (typically `1 / sqrt(D)`)
- Output — `{B, N_q, T_q, D}`, same dtype as `q`
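A prefill sketch with illustrative shapes, assuming an EMLX backend; no mask tensor is built on the Elixir side:

```elixir
# Prefill with T_q = T_kv = 16: the causal mask is applied inside the kernel.
q = Nx.iota({1, 8, 16, 64}, type: :f32)
k = Nx.iota({1, 2, 16, 64}, type: :f32)   # GQA: fewer KV heads
v = Nx.iota({1, 2, 16, 64}, type: :f32)

EMLX.Fast.scaled_dot_product_attention_causal(q, k, v, 0.125)
# => {1, 8, 16, 64}, same dtype as q
```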
Causal SDPA with the key_mask check delegated to the C++ NIF.
At runtime the NIF evaluates `all(key_mask == 1)`:

- `true` (no padding, e.g. single-sequence decode) → pure causal SDPA, no mask tensor allocated.
- `false` (padded batch or multi-sequence) → builds a combined causal + `key_mask` additive mask and calls the masked SDPA kernel.
This avoids the Nx.cond double-evaluation problem: the NIF forces eval
of only the small {B, T_kv} key_mask subgraph, then branches in C++.
Input/output layout matches scaled_dot_product_attention_causal/4:
- `q` — `{B, N_q, T_q, D}`
- `k` — `{B, N_kv, T_kv, D}`
- `v` — `{B, N_kv, T_kv, D}`
- `scale` — pre-computed scalar
- `key_mask` — `{B, T_kv}` boolean/int tensor (1 = attend, 0 = masked)
- Output — `{B, N_q, T_q, D}`, same dtype as `q`
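A padded-batch sketch with illustrative shapes, assuming an EMLX backend. The first row is unpadded (the NIF would fast-path it if it were the whole batch); the second row left-pads two keys:

```elixir
key_mask = Nx.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])  # {B, T_kv}

q = Nx.iota({2, 4, 4, 64}, type: :f32)
k = Nx.iota({2, 4, 4, 64}, type: :f32)
v = Nx.iota({2, 4, 4, 64}, type: :f32)

# key_mask is evaluated in C++; only if any entry is 0 does the
# combined causal + padding mask get built.
EMLX.Fast.scaled_dot_product_attention_causal_key_masked(q, k, v, 0.125, key_mask)
```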
Fused SwiGLU activation: silu(gate) * up where silu(x) = x * sigmoid(x).
Eliminates the two-op silu(gate_proj) * up_proj pattern that appears in
Qwen3's FFN layers (28× per decode step).
- `gate` — gate-projection output; silu is applied element-wise.
- `up` — up-projection output; same shape as `gate`.
Output has the same shape and dtype as gate.
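A minimal sketch, assuming an EMLX backend; the fused call replaces the two-op Nx equivalent shown in the comment:

```elixir
gate = Nx.tensor([1.0, -2.0, 0.5])
up = Nx.tensor([2.0, 2.0, 2.0])

# One kernel instead of: gate |> Nx.sigmoid() |> Nx.multiply(gate) |> Nx.multiply(up)
EMLX.Fast.swiglu(gate, up)
```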