Axon model rewrites that swap supported nodes to EMLX.Fast Metal shaders.
Pass an %Axon{} model through rewrite/1 before compiling it with
Axon.build/2 or Bumblebee.Text.generation/4 to replace supported
normalization and attention nodes with single-kernel MLX equivalents.
Supported rewrites
| Key | Matched node | Replaced with |
|---|---|---|
| `:rms_norm` | `op_name: :rms_norm`, `shift: 0.0` | `EMLX.Fast.rms_norm/3` |
| `:layer_norm` | `op_name: :layer_norm` | `EMLX.Fast.layer_norm/3,4` |
| `:rotary_embedding` | Bumblebee `apply_rotary_embedding/5` | `EMLX.Fast.rope_with_positions/6` |
| `:sdpa` | Bumblebee `attention_output_impl/3` | `EMLX.Fast.scaled_dot_product_attention_causal/4` or unmasked |
| `:dropout` | `op_name: :dropout` (inference) | identity pass-through |
| `:swiglu` | `:multiply(container(up, silu(gate)))` | `EMLX.Fast.swiglu/2` |
| `:native_attention` | Bumblebee causal self-attention | `EMLX.kv_cache_attention_masked/8` |
Usage
    {:ok, %{model: model, params: params, spec: spec}} =
      Bumblebee.load_model({:hf, "Qwen/Qwen3-0.6B"})

    model = EMLXAxon.rewrite(model)

    # tokenizer and generation_config are loaded with Bumblebee.load_tokenizer/1
    # and Bumblebee.load_generation_config/1 as usual.
    serving =
      Bumblebee.Text.generation(
        %{model: model, params: params, spec: spec},
        tokenizer,
        generation_config,
        compile: [batch_size: 1, sequence_length: 256]
      )

Limitations
- The `:rms_norm` rewrite requires `shift: 0.0`. Nodes with a non-zero shift are
  skipped because `EMLX.Fast.rms_norm(x, w, eps)` computes `x/rms(x)*w`, not
  `x/rms(x)*(shift+w)`.
- The `:rotary_embedding` rewrite assumes sequential position IDs within each
  batch example (standard causal LM). Non-sequential schemes (packed sequences,
  custom position offsets) will produce incorrect results. Bumblebee's
  `apply_rotary_embedding/5` is matched by function identity via
  `function_info/1` — this is tied to Bumblebee's internal implementation and
  may break across major Bumblebee version changes. If an assumption does not
  hold, disable the rewrite via the `:only` option, as shown after this list.
- RoPE scaling strategies: `:llama3` precomputes the inv-frequency tensor at
  rewrite time and dispatches to `EMLX.Fast.rope_with_freqs/6`. Other strategies
  (`:linear`, `:dynamic`, `:longrope`) fall back to `rope_with_positions` with
  the standard base frequency — they are not frequency-precomputable because
  they are seq-len conditional or require a post-multiply of cos/sin that
  `mlx::fast::rope` cannot absorb.
- The `:sdpa` rewrite threads `key_mask` (from input padding) through to the
  C++ NIF, which checks at runtime whether the mask is all-ones and fast-paths
  to the pure causal Metal kernel when it is. Padded batches get a combined
  causal + `key_mask` additive mask. Sliding-window attention falls back to the
  original `attention_output_impl`. Inference-only: dropout is elided.
- The `:dropout` rewrite replaces `op_name: :dropout` nodes with an identity
  pass-through. Dropout at inference time is always a no-op regardless of rate,
  so the rewrite eliminates the NIF-boundary crossings per decode step without
  any functional change. Not appropriate for training graphs.
- The `:swiglu` rewrite matches `:multiply` nodes backed by a `:container` node
  whose two parents include one `:silu` node (the Bumblebee SwiGLU pattern
  `multiply(container(up_proj, silu(gate_proj)))`). It replaces the
  multiply + container + silu triple with a single `EMLX.Fast.swiglu/2` call.
  It does not match generic multiplications or containers without a silu child.
- The `:attn_weights` rewrite replaces `:bb_attn_weights` passthrough nodes
  (added by the local Bumblebee patch) with a no-arg constant-zero layer,
  cutting the `attention_weights_impl` sub-graph and its K-side
  `repeat_interleave` nodes out of the reachable graph entirely. Inference-only:
  the attention weights tensor is never used for token generation.
- The `:if_present` rewrite replaces Bumblebee's KV-cache conditional nodes with
  their "cache present" branch. In compiled serving the KV cache is always
  initialized (never `%Axon.None{}`), so the else branch is dead code. Removing
  the `:if_present` nodes and their `:optional` wrappers eliminates per-step
  Axon dispatch overhead without any functional change.
- The `:gqa_cache_fix` rewrite fixes a shape mismatch that arises when GQA head
  expansion (`repeat_interleave`) runs before `update_attention_cache`. The
  standard Bumblebee transformer block expands keys/values from
  `num_key_value_heads` to `num_attention_heads` before the cache update, but
  the cache is allocated with `num_key_value_heads`. This rewrite strips the
  `repeat_interleave` from the key and value inputs to every
  `update_attention_cache` Axon layer, so the cache receives the compact GQA
  tensors. The SDPA rewriter (`maybe_strip_repeat_interleave`) already handles
  the expanded-head removal on the SDPA side, and MLX fast SDPA handles GQA
  natively.
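When one of these assumptions does not hold for a given model, the affected
rewrite can be skipped via the `:only` option of `rewrite/2`. A sketch (the
rewrite list shown is illustrative, not a recommended default):

    # Keep the fast ops but skip :rotary_embedding, e.g. for packed sequences.
    model =
      EMLXAxon.rewrite(model,
        only: [:rms_norm, :layer_norm, :sdpa, :dropout, :swiglu, :attn_weights, :if_present]
      )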
Summary
Functions
Returns the rewriter function for Bumblebee aux attention-weights nodes.
Returns the rewriter function for dropout nodes.
Extracts {module, name, arity} from a function reference, or returns nil
for non-function values.
Returns the rewriter function for GQA key/value cache shape fix.
Returns the rewriter function for :if_present nodes.
Returns the rewriter function for layer_norm nodes.
Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.
Returns the rewriter function for Bumblebee causal self-attention nodes.
Returns the rewriter function for Bumblebee block-cache update nodes.
Rewrites all supported nodes in model to their EMLX.Fast equivalents.
Returns the rewriter function for rms_norm nodes.
Returns the rewriter function for Bumblebee's rotary_embedding nodes.
Returns the rewriter function for Bumblebee's attention output nodes.
Returns the rewriter function for SwiGLU nodes.
Functions
Returns the rewriter function for Bumblebee aux attention-weights nodes.
Matches :bb_attn_weights nodes — a passthrough layer inserted by the local
Bumblebee patch around the {output, weights} return of Layers.attention/8.
Replaces the node with a no-arg constant-zero layer so the entire
attention_weights_impl sub-graph (and the K-side repeat_interleave it
consumes) becomes unreachable. Inference-only: attention weight tensors are
never used for token generation.
Returns the rewriter function for dropout nodes.
At inference time, dropout is always a pass-through regardless of rate. This
rewriter replaces every :dropout node with an identity layer, eliminating the
NIF-boundary crossing without any functional change.
Not appropriate for training graphs — enable this rewriter only when the model
will be used for inference.
@spec function_info(term()) :: {module(), atom(), non_neg_integer()} | nil
Extracts {module, name, arity} from a function reference, or returns nil
for non-function values.
Works for both named functions (def/defp/defn/defnp) and closures.
Closures report the module where they were defined and a generated name like
"-foo/2-fun-0-", which is distinct from any hand-written function name and
therefore safe to use in MFA comparisons.
Note: Nx's defnp may compile to a closure rather than a named function,
so this helper intentionally does not filter on the :type reported by
:erlang.fun_info/2.
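A minimal sketch of this kind of MFA extraction, built directly on
`Function.info/1` (not the exact implementation used here):

    defmodule FunctionInfoSketch do
      # Named captures and closures both report module/name/arity via Function.info/1;
      # non-function values yield nil.
      def function_info(fun) when is_function(fun) do
        info = Function.info(fun)
        {info[:module], info[:name], info[:arity]}
      end

      def function_info(_other), do: nil
    end

    FunctionInfoSketch.function_info(&Enum.map/2)
    #=> {Enum, :map, 2}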
Returns the rewriter function for GQA key/value cache shape fix.
In the standard Bumblebee transformer block, GQA head expansion via
repeat_interleave (expanding from num_key_value_heads to num_attention_heads)
is applied to the key and value tensors before update_attention_cache. However,
init_cache allocates the preallocated buffer with num_key_value_heads, causing
a shape mismatch at Axon compile time when num_key_value_heads < num_attention_heads.
This rewriter fixes the graph by stripping the repeat_interleave node from the
key and value inputs of every update_attention_cache layer. The cache update then
operates on the compact GQA tensors. The SDPA rewriter (maybe_strip_repeat_interleave)
separately handles the expanded-head removal on the SDPA path, and MLX fast SDPA
handles GQA natively without explicit head repetition.
Only applies when the key or value parent is a Bumblebee repeat_interleave node.
Models without GQA (or where repeat_interleave is already absent) are unaffected.
Returns the rewriter function for :if_present nodes.
Bumblebee wraps every KV-cache operation in Layers.if_present(cache, ...) to
handle the case where no cache is provided. In compiled serving the cache is
always initialized (never %Axon.None{}), so the conditional is dead code.
The rewriter unconditionally selects the "cache present" branch (on_true) and
lets the "no cache" branch (on_false) and all its :optional wrappers become
unreachable, pruning them from the compiled graph.
Do not enable for training graphs — training models typically run without a
KV cache and rely on the else branch.
Returns the rewriter function for layer_norm nodes.
Replaces op_name: :layer_norm nodes (Axon's built-in layer normalization)
with an Axon layer that calls EMLX.Fast.layer_norm/3,4 — a single fused
Metal shader. Skips nodes where channel_index is not -1 (last axis),
as the kernel only normalizes over the last axis.
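For reference, a last-axis layer norm with the same shape semantics can be
written in plain `Nx.Defn` (a sketch with an illustrative epsilon and an
assumed `(x, weight, bias)` argument order, not the NIF itself):

    defmodule LayerNormRef do
      import Nx.Defn

      # Normalize over the last axis, then scale and shift.
      defn layer_norm_ref(x, weight, bias) do
        mean = Nx.mean(x, axes: [-1], keep_axes: true)
        variance = Nx.variance(x, axes: [-1], keep_axes: true)
        (x - mean) / Nx.sqrt(variance + 1.0e-5) * weight + bias
      end
    end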
Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.
Combines three steps into one call:
- Loads the Axon model structure from `config.json` via `Bumblebee.load_model/2`.
- Loads the MLX-4bit safetensors weights via `EMLXAxon.MLX4BitParams.load/1`,
  dequantizing and transposing to Bumblebee `{in, out}` layout (BF16).
- Re-quantizes all eligible weight matrices via `EMLXAxon.QuantizeParams.quantize/1`
  so that `Nx.dot` dispatch routes to `EMLX.quantized_matmul` at serving time.
Returns {:ok, model_info} compatible with Bumblebee.Text.generation/4.
Usage
    path = "~/models/Qwen3-0.6B-MLX-4bit"

    {:ok, model_info} = EMLXAxon.load_quantized({:local, path})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:local, path})
    {:ok, gen_cfg} = Bumblebee.load_generation_config({:local, path})
    gen_cfg = Bumblebee.configure(gen_cfg, max_new_tokens: 100)

    serving =
      Bumblebee.Text.generation(model_info, tokenizer, gen_cfg,
        compile: [batch_size: 1, sequence_length: 256],
        defn_options: [compiler: EMLX]
      )

    result = Nx.Serving.run(serving, "The capital of France is")

Notes
- Do not apply `EMLXAxon.rewrite/2` after `load_quantized` — the rotary
  embedding rewrite is incompatible with the standard Bumblebee
  `native_kv_cache: false` path and produces incorrect outputs. BF16 fast ops
  (rms_norm, swiglu, dropout, sdpa) may be added once the rotary embedding
  rewrite is fixed.
- Model architecture is inferred from `config.json` in the checkpoint
  directory. Validated with Bumblebee and Qwen3-0.6B.
- Quantization metadata: QuantizeParams logs shape-mismatch warnings for
  tensors whose physical packed dimensions differ from the Bumblebee model's
  expected shapes. These warnings are benign — the quantized tensors are still
  used correctly via the EMLX backend's quantized_matmul dispatch.
Returns the rewriter function for Bumblebee causal self-attention nodes.
This rewrite replaces the attention output with a single Nx.runtime_call
callback that updates a process-local ETS K/V cache and calls
EMLX.kv_cache_attention_masked/8. It intentionally only matches causal
attention without sliding-window masking; cross-attention and local attention
fall back to the original graph.
@spec nullify_block_cache_rewriter() :: (Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
Returns the rewriter function for Bumblebee block-cache update nodes.
When native attention owns K/V state in an ETS table, the Axon block-cache
update chain is dead. Replacing put_block_cache with an identity lets DCE
prune get_block_cache, update_attention_cache, and container plumbing.
Rewrites all supported nodes in model to their EMLX.Fast equivalents.
Options
- `:only` — list of atoms selecting which rewrites to apply. Defaults to
  `[:rms_norm, :layer_norm, :rotary_embedding, :sdpa, :dropout, :swiglu,
  :attn_weights, :if_present, :native_attention, :nullify_block_cache]`.
  Pass `:gqa_cache_fix` explicitly when targeting a Bumblebee build whose
  `init_cache` allocates the KV cache with `num_key_value_heads` rather than
  `num_attention_heads` (i.e. the upstream PR branch patch).
Example
    model = EMLXAxon.rewrite(model)
    model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm])
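To opt into the GQA cache shape fix described in the `:only` option above
(the other entries in the list are illustrative):

    model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm, :gqa_cache_fix])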
Returns the rewriter function for rms_norm nodes.
Replaces op_name: :rms_norm nodes whose shift is 0.0 with an Axon layer
that calls EMLX.Fast.rms_norm/3 — a single fused Metal shader.
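For reference, the shift-free RMSNorm this replaces can be written in plain
`Nx.Defn` (a sketch; the epsilon value and `(x, weight)` argument order are
illustrative, not the NIF itself):

    defmodule RmsNormRef do
      import Nx.Defn

      # x / rms(x) * weight, with no shift term (matching the shift: 0.0 restriction).
      defn rms_norm_ref(x, weight) do
        rms = Nx.sqrt(Nx.mean(x * x, axes: [-1], keep_axes: true) + 1.0e-6)
        x / rms * weight
      end
    end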
@spec rotary_embedding_rewriter(reference() | nil) :: (Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
Returns the rewriter function for Bumblebee's rotary_embedding nodes.
Matches %Axon.Node{op: &Bumblebee.Layers.apply_rotary_embedding/5} by MFA
identity via function_info/1, then replaces it with an EMLX.Fast.rope_with_positions/6
call on both Q and K. The replacement node returns {q_rotated, k_rotated} —
downstream Axon.nx(_, &elem(&1, i)) unwrap nodes continue to work unchanged.
Assumes sequential positions — see EMLXAxon moduledoc for the limitation.
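Here "sequential positions" means position IDs equivalent to an iota over the
sequence axis, one 0..seq_len-1 ramp per batch row. A sketch with placeholder
dimensions:

    {batch, seq_len} = {1, 8}
    # Position ids 0..7 for every batch row: the layout this rewrite assumes.
    positions = Nx.iota({batch, seq_len}, axis: 1)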
Returns the rewriter function for Bumblebee's attention output nodes.
Matches %Axon.Node{op: &Bumblebee.Layers.attention_output_impl/3} by MFA
identity via function_info/1, then navigates up through the dropout and
attention_weights_impl nodes to recover Q and K, and replaces the whole
attention chain with a single EMLX.Fast SDPA call.
- `causal: true`, no `window_size` — uses
  `scaled_dot_product_attention_causal_key_masked/5`. The `key_mask` is
  threaded through; the C++ NIF checks if it is all-ones at runtime and
  dispatches to the pure causal Metal kernel (no mask allocation) or builds a
  combined additive mask for padded batches.
- `causal: false`, no `window_size` — uses `scaled_dot_product_attention/4`
  (unmasked, for cross-attention or prefix LM heads).
- `window_size` set — re-applies the original `attention_output_impl` unchanged.
Inference-only: attention dropout is elided (a no-op at inference time). Nodes
with dropout_rate > 0 are skipped to preserve training-time stochastic behaviour.
Returns the rewriter function for SwiGLU nodes.
Matches :multiply nodes backed by a single :container parent whose two
children include one :silu node (the Bumblebee SwiGLU pattern:
multiply(container(up_proj, silu(gate_proj)))). Replaces the
multiply + container + silu triple with a single EMLX.Fast.swiglu/2 call,
passing the gate's raw input (pre-silu) and the up-projection directly to the
fused NIF.
Generic :multiply nodes (no :container parent, or container without a
:silu child) are reconstructed identically.
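For reference, the fused call computes `up * silu(gate)`, where
`silu(x) = x * sigmoid(x)`. A plain `Nx` sketch (the `(gate, up)` argument
order of `EMLX.Fast.swiglu/2` is an assumption here):

    gate = Nx.tensor([[0.5, -1.0]])
    up = Nx.tensor([[2.0, 3.0]])
    # Equivalent of multiply(container(up, silu(gate))) in one expression.
    swiglu = Nx.multiply(up, Nx.multiply(gate, Nx.sigmoid(gate)))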