EMLXAxon (emlx_axon v0.3.0)


Axon model rewrites that swap supported nodes for EMLX.Fast Metal shaders.

Pass an %Axon{} model through rewrite/1 before compiling it with Axon.build/2 or Bumblebee.Text.generation/4 to replace supported normalization and attention nodes with single-kernel MLX equivalents.

Supported rewrites

| Key | Matched node | Replaced with |
|---|---|---|
| :rms_norm | op_name: :rms_norm, shift: 0.0 | EMLX.Fast.rms_norm/3 |
| :layer_norm | op_name: :layer_norm | EMLX.Fast.layer_norm/3,4 |
| :rotary_embedding | Bumblebee apply_rotary_embedding/5 | EMLX.Fast.rope_with_positions/6 |
| :sdpa | Bumblebee attention_output_impl/3 | EMLX.Fast.scaled_dot_product_attention_causal/4 or unmasked |
| :dropout | op_name: :dropout (inference) | identity pass-through |
| :swiglu | :multiply(container(up, silu(gate))) | EMLX.Fast.swiglu/2 |
| :native_attention | Bumblebee causal self-attention | EMLX.kv_cache_attention_masked/8 |

Usage

{:ok, %{model: model, params: params, spec: spec}} =
  Bumblebee.load_model({:hf, "Qwen/Qwen3-0.6B"})

{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "Qwen/Qwen3-0.6B"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "Qwen/Qwen3-0.6B"})

model = EMLXAxon.rewrite(model)

serving =
  Bumblebee.Text.generation(
    %{model: model, params: params, spec: spec},
    tokenizer,
    generation_config,
    compile: [batch_size: 1, sequence_length: 256]
  )

Limitations

  • :rms_norm rewrite requires shift: 0.0. Nodes with a non-zero shift are skipped because EMLX.Fast.rms_norm(x, w, eps) computes x/rms(x)*w, not x/rms(x)*(shift+w).

  • :rotary_embedding rewrite assumes sequential position IDs within each batch example (standard causal LM). Non-sequential schemes (packed sequences, custom position offsets) will produce incorrect results. Bumblebee's apply_rotary_embedding/5 is matched by function identity via function_info/1 — this is tied to Bumblebee's internal implementation and may break across major Bumblebee version changes.

    RoPE scaling strategies: :llama3 precomputes the inverse-frequency tensor at rewrite time and dispatches to EMLX.Fast.rope_with_freqs/6. Other strategies (:linear, :dynamic, :longrope) fall back to rope_with_positions with the standard base frequency — their frequencies cannot be precomputed, because they are conditional on sequence length or require a post-multiplication of cos/sin that mlx::fast::rope cannot absorb.

  • :sdpa rewrite threads key_mask (from input padding) through to the C++ NIF, which checks at runtime whether the mask is all-ones and fast-paths to the pure causal Metal kernel when it is. Padded batches get a combined causal + key_mask additive mask. Sliding-window attention falls back to the original attention_output_impl. Inference-only: dropout is elided.

  • :dropout rewrite replaces op_name: :dropout nodes with an identity pass-through. Dropout at inference time is always a no-op regardless of rate, so the rewrite eliminates the NIF-boundary crossings per decode step without any functional change. Not appropriate for training graphs.

  • :swiglu rewrite matches :multiply nodes backed by a :container node whose two parents include one :silu node (the Bumblebee SwiGLU pattern: multiply(container(up_proj, silu(gate_proj)))). Replaces the multiply + container + silu triple with a single EMLX.Fast.swiglu/2 call. Does not match generic multiplications or containers without a silu child.

  • :attn_weights rewrite replaces :bb_attn_weights passthrough nodes (added by the local Bumblebee patch) with a no-arg constant-zero layer, cutting the attention_weights_impl sub-graph and its K-side repeat_interleave nodes out of the reachable graph entirely. Inference-only: the attention weights tensor is never used for token generation.

  • :if_present rewrite replaces Bumblebee's KV-cache conditional nodes with their "cache present" branch. In compiled serving the KV cache is always initialized (never %Axon.None{}), so the else branch is dead code. Removing the :if_present nodes and their :optional wrappers eliminates per-step Axon dispatch overhead without any functional change.

  • :gqa_cache_fix rewrite fixes a shape mismatch that arises when GQA head expansion (repeat_interleave) runs before update_attention_cache. The standard Bumblebee transformer block expands keys/values from num_key_value_heads to num_attention_heads before the cache update, but the cache is allocated with num_key_value_heads. This rewrite strips the repeat_interleave from the key and value inputs to every update_attention_cache Axon layer, so the cache receives the compact GQA tensors. The SDPA rewriter (maybe_strip_repeat_interleave) already handles the expanded-head removal on the SDPA side, and MLX fast SDPA handles GQA natively.

Summary

Functions

attn_weights_rewriter() — Returns the rewriter function for Bumblebee aux attention-weights nodes.

dropout_rewriter() — Returns the rewriter function for dropout nodes.

function_info(fun) — Extracts {module, name, arity} from a function reference, or returns nil for non-function values.

gqa_cache_fix_rewriter() — Returns the rewriter function for GQA key/value cache shape fix.

if_present_rewriter() — Returns the rewriter function for :if_present nodes.

layer_norm_rewriter() — Returns the rewriter function for layer_norm nodes.

load_quantized(source, opts \\ []) — Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.

native_attention_rewriter() — Returns the rewriter function for Bumblebee causal self-attention nodes.

nullify_block_cache_rewriter() — Returns the rewriter function for Bumblebee block-cache update nodes.

rewrite(model, opts \\ []) — Rewrites all supported nodes in model to their EMLX.Fast equivalents.

rms_norm_rewriter() — Returns the rewriter function for rms_norm nodes.

rotary_embedding_rewriter(cache \\ nil) — Returns the rewriter function for Bumblebee's rotary_embedding nodes.

sdpa_rewriter() — Returns the rewriter function for Bumblebee's attention output nodes.

swiglu_rewriter() — Returns the rewriter function for SwiGLU nodes.

Functions

attn_weights_rewriter()

@spec attn_weights_rewriter() :: (Axon.Node.t() ->
                              :skip | ([Axon.t(), ...], Axon.t() -> Axon.t()))

Returns the rewriter function for Bumblebee aux attention-weights nodes.

Matches :bb_attn_weights nodes — a passthrough layer inserted by the local Bumblebee patch around the {output, weights} return of Layers.attention/8. Replaces the node with a no-arg constant-zero layer so the entire attention_weights_impl sub-graph (and the K-side repeat_interleave it consumes) becomes unreachable. Inference-only: attention weight tensors are never used for token generation.
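
In sketch form, the rewriter looks roughly like this — a minimal illustration of the contract in the @spec above, not the shipped implementation:

fn
  %Axon.Node{op_name: :bb_attn_weights} ->
    # Builder: ignore the original inputs and emit a constant-zero layer;
    # everything that only fed the weights output becomes unreachable.
    fn _inputs, _original -> Axon.constant(Nx.tensor(0.0)) end

  _node ->
    :skip
end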

dropout_rewriter()

@spec dropout_rewriter() :: (Axon.Node.t() ->
                         ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for dropout nodes.

At inference time, dropout is always a pass-through regardless of rate. This rewriter replaces every :dropout node with an identity layer, eliminating the NIF-boundary crossing without any functional change.

Not appropriate for training graphs — only enable this rewriter when the model will be used for inference only.
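
This is the simplest instance of the rewriter contract. A minimal sketch, assuming each :dropout node has exactly one parent to pass through:

fn
  %Axon.Node{op_name: :dropout} ->
    # Builder: wire the single input straight through, dropping the op.
    fn [input], _original -> input end

  _node ->
    :skip
end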

function_info(fun)

@spec function_info(term()) :: {module(), atom(), non_neg_integer()} | nil

Extracts {module, name, arity} from a function reference, or returns nil for non-function values.

Works for both named functions (def/defp/defn/defnp) and closures. Closures report the module where they were defined and a generated name like "-foo/2-fun-0-", which is distinct from any hand-written function name and therefore safe to use in MFA comparisons.

Note: Nx's defnp may compile to a closure rather than a named function, so this helper intentionally does not filter by :erlang.fun_info(:type).
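
A minimal sketch of the extraction, assuming it is built directly on :erlang.fun_info/2 (the shipped implementation may differ):

def function_info(fun) when is_function(fun) do
  {:module, mod} = :erlang.fun_info(fun, :module)
  {:name, name} = :erlang.fun_info(fun, :name)
  {:arity, arity} = :erlang.fun_info(fun, :arity)
  {mod, name, arity}
end

def function_info(_non_function), do: nil

# function_info(&Bumblebee.Layers.apply_rotary_embedding/5)
# #=> {Bumblebee.Layers, :apply_rotary_embedding, 5}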

gqa_cache_fix_rewriter()

@spec gqa_cache_fix_rewriter() :: (Axon.Node.t() ->
                               ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for GQA key/value cache shape fix.

In the standard Bumblebee transformer block, GQA head expansion via repeat_interleave (expanding from num_key_value_heads to num_attention_heads) is applied to the key and value tensors before update_attention_cache. However, init_cache allocates the preallocated buffer with num_key_value_heads, causing a shape mismatch at Axon compile time when num_key_value_heads < num_attention_heads.

This rewriter fixes the graph by stripping the repeat_interleave node from the key and value inputs of every update_attention_cache layer. The cache update then operates on the compact GQA tensors. The SDPA rewriter (maybe_strip_repeat_interleave) separately handles the expanded-head removal on the SDPA path, and MLX fast SDPA handles GQA natively without explicit head repetition.

Only applies when the key or value parent is a Bumblebee repeat_interleave node. Models without GQA (or where repeat_interleave is already absent) are unaffected.
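
A shape walk-through with hypothetical GQA dimensions (the numbers are illustrative, not read from any model config):

# num_attention_heads = 16, num_key_value_heads = 8, head_dim = 64
#
# key after repeat_interleave:  {batch, seq, 16, 64}     (what the cache update received)
# cache from init_cache:        {batch, max_seq, 8, 64}  (compile-time shape mismatch)
#
# With repeat_interleave stripped, update_attention_cache sees
# {batch, seq, 8, 64} and the buffer shapes line up.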

if_present_rewriter()

@spec if_present_rewriter() :: (Axon.Node.t() ->
                            ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for :if_present nodes.

Bumblebee wraps every KV-cache operation in Layers.if_present(cache, ...) to handle the case where no cache is provided. In compiled serving the cache is always initialized (never %Axon.None{}), so the conditional is dead code.

The rewriter unconditionally selects the "cache present" branch (on_true) and lets the "no cache" branch (on_false) and all its :optional wrappers become unreachable, pruning them from the compiled graph.

Do not enable for training graphs — training models typically run without a KV cache and rely on the else branch.
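
In sketch form (the input ordering shown here is illustrative; the real node layout may differ):

fn
  %Axon.Node{op_name: :if_present} ->
    # Builder: unconditionally select the "cache present" branch.
    fn [on_true | _rest], _original -> on_true end

  _node ->
    :skip
end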

layer_norm_rewriter()

@spec layer_norm_rewriter() :: (Axon.Node.t() ->
                            ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for layer_norm nodes.

Replaces op_name: :layer_norm nodes (Axon's built-in layer normalization) with an Axon layer that calls EMLX.Fast.layer_norm/3,4 — a single fused Metal shader. Skips nodes where channel_index is not -1 (last axis), as the kernel only normalizes over the last axis.
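
What the fused kernel computes, written out as an Nx sketch over the last axis (gamma and beta are the learned scale and offset, eps the numerical epsilon):

mean = Nx.mean(x, axes: [-1], keep_axes: true)
var = Nx.variance(x, axes: [-1], keep_axes: true)

x
|> Nx.subtract(mean)
|> Nx.divide(Nx.sqrt(Nx.add(var, eps)))
|> Nx.multiply(gamma)
|> Nx.add(beta)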

load_quantized(source, opts \\ [])

@spec load_quantized({:local, Path.t()}, keyword()) :: {:ok, map()} | {:error, term()}

Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.

Combines three steps into one call:

  1. Loads the Axon model structure from config.json via Bumblebee.load_model/2.
  2. Loads the MLX-4bit safetensors weights via EMLXAxon.MLX4BitParams.load/1, dequantizing and transposing to Bumblebee {in, out} layout (BF16).
  3. Re-quantizes all eligible weight matrices via EMLXAxon.QuantizeParams.quantize/1 so that Nx.dot dispatch routes to EMLX.quantized_matmul at serving time.

Returns {:ok, model_info} compatible with Bumblebee.Text.generation/4.
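
The composition, as a sketch built from the helper modules named above (exact signatures and error handling may differ):

with {:ok, model_info} <- Bumblebee.load_model({:local, path}) do
  params =
    path
    |> EMLXAxon.MLX4BitParams.load()       # dequantize to BF16, transpose to {in, out}
    |> EMLXAxon.QuantizeParams.quantize()  # re-quantize eligible matrices for EMLX
  {:ok, %{model_info | params: params}}
end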

Usage

path = Path.expand("~/models/Qwen3-0.6B-MLX-4bit")

{:ok, model_info} = EMLXAxon.load_quantized({:local, path})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:local, path})
{:ok, gen_cfg}   = Bumblebee.load_generation_config({:local, path})
gen_cfg = Bumblebee.configure(gen_cfg, max_new_tokens: 100)

serving = Bumblebee.Text.generation(model_info, tokenizer, gen_cfg,
  compile: [batch_size: 1, sequence_length: 256],
  defn_options: [compiler: EMLX]
)

result = Nx.Serving.run(serving, "The capital of France is")

Notes

  • Do not apply EMLXAxon.rewrite/2 after load_quantized — the rotary embedding rewrite is incompatible with the standard Bumblebee native_kv_cache: false path and produces incorrect outputs. BF16 fast ops (rms_norm, swiglu, dropout, sdpa) may be added once the rotary embedding rewrite is fixed.

  • Model architecture is inferred from config.json in the checkpoint directory. Validated with Bumblebee and Qwen3-0.6B.

  • Quantization metadata: QuantizeParams logs shape-mismatch warnings for tensors whose physical packed dimensions differ from the Bumblebee model's expected shapes. These warnings are benign — the quantized tensors are still used correctly via the EMLX backend's quantized_matmul dispatch.

native_attention_rewriter()

@spec native_attention_rewriter() :: (Axon.Node.t() ->
                                  ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for Bumblebee causal self-attention nodes.

This rewrite replaces the attention output with a single Nx.runtime_call callback that updates a process-local ETS K/V cache and calls EMLX.kv_cache_attention_masked/8. It intentionally only matches causal attention without sliding-window masking; cross-attention and local attention fall back to the original graph.

nullify_block_cache_rewriter()

@spec nullify_block_cache_rewriter() :: (Axon.Node.t() ->
                                     ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for Bumblebee block-cache update nodes.

When native attention owns K/V state in an ETS table, the Axon block-cache update chain is dead. Replacing put_block_cache with an identity lets DCE prune get_block_cache, update_attention_cache, and container plumbing.

rewrite(model, opts \\ [])

@spec rewrite(
  Axon.t(),
  keyword()
) :: Axon.t()

Rewrites all supported nodes in model to their EMLX.Fast equivalents.

Options

  • :only — list of atoms selecting which rewrites to apply. Defaults to [:rms_norm, :layer_norm, :rotary_embedding, :sdpa, :dropout, :swiglu, :attn_weights, :if_present, :native_attention, :nullify_block_cache]. Pass :gqa_cache_fix explicitly when targeting a Bumblebee build whose init_cache allocates the KV cache with num_key_value_heads rather than num_attention_heads (i.e. the upstream PR branch patch).

Example

model = EMLXAxon.rewrite(model)
model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm])

rms_norm_rewriter()

@spec rms_norm_rewriter() :: (Axon.Node.t() ->
                          ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for rms_norm nodes.

Replaces op_name: :rms_norm nodes with shift: 0.0 with an Axon layer that calls EMLX.Fast.rms_norm/3 — a single fused Metal shader.

rotary_embedding_rewriter(cache \\ nil)

@spec rotary_embedding_rewriter(reference() | nil) :: (Axon.Node.t() ->
                                                   ([Axon.t()], Axon.t() ->
                                                      Axon.t())
                                                   | :skip)

Returns the rewriter function for Bumblebee's rotary_embedding nodes.

Matches %Axon.Node{op: &Bumblebee.Layers.apply_rotary_embedding/5} by MFA identity via function_info/1, then replaces it with an EMLX.Fast.rope_with_positions/6 call on both Q and K. The replacement node returns {q_rotated, k_rotated} — downstream Axon.nx(_, &elem(&1, i)) unwrap nodes continue to work unchanged.

Assumes sequential positions — see EMLXAxon moduledoc for the limitation.
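
The tuple return is what keeps the downstream graph intact; the unwrap nodes mentioned above select each element:

q_rotated = Axon.nx(rotated, &elem(&1, 0))
k_rotated = Axon.nx(rotated, &elem(&1, 1))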

sdpa_rewriter()

@spec sdpa_rewriter() :: (Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for Bumblebee's attention output nodes.

Matches %Axon.Node{op: &Bumblebee.Layers.attention_output_impl/3} by MFA identity via function_info/1, then navigates up through the dropout and attention_weights_impl nodes to recover Q and K, and replaces the whole attention chain with a single EMLX.Fast SDPA call.

  • causal: true, no window_size — uses scaled_dot_product_attention_causal_key_masked/5. The key_mask is threaded through; the C++ NIF checks if it is all-ones at runtime and dispatches to the pure causal Metal kernel (no mask allocation) or builds a combined additive mask for padded batches.
  • causal: false, no window_size — uses scaled_dot_product_attention/4 (unmasked, for cross-attention or prefix LM heads).
  • window_size set — re-applies the original attention_output_impl unchanged.

Inference-only: attention dropout is elided (a no-op at inference time). Nodes with dropout_rate > 0 are skipped to preserve training-time stochastic behaviour.
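
What the padded-batch path builds, as an Nx sketch of the combined causal + key_mask additive mask (the actual construction happens inside the C++ NIF; shapes and the -1.0e9 fill value are illustrative):

# key_mask: {batch, kv_len}, 1 for real tokens, 0 for padding.
# The kv_len - q_len offset aligns query rows with the tail of the
# key sequence during incremental decode.
causal =
  Nx.greater_equal(
    Nx.add(Nx.iota({q_len, kv_len}, axis: 0), kv_len - q_len),
    Nx.iota({q_len, kv_len}, axis: 1)
  )

keep = Nx.logical_and(causal, Nx.reshape(key_mask, {batch, 1, 1, kv_len}))
additive = Nx.select(keep, 0.0, -1.0e9)  # added to the QK^T logits before softmax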

swiglu_rewriter()

@spec swiglu_rewriter() :: (Axon.Node.t() ->
                        ([Axon.t()], Axon.t() -> Axon.t()) | :skip)

Returns the rewriter function for SwiGLU nodes.

Matches :multiply nodes backed by a single :container parent whose own two inputs include one :silu node (the Bumblebee SwiGLU pattern: multiply(container(up_proj, silu(gate_proj)))). Replaces the multiply + container + silu triple with a single EMLX.Fast.swiglu/2 call, passing the gate's raw input (pre-silu) and the up-projection directly to the fused NIF.

Generic :multiply nodes (no :container parent, or container without a :silu child) are reconstructed identically.
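
What the fused call computes, as an Nx sketch (assuming the (gate, up) argument order described above):

# EMLX.Fast.swiglu(gate, up) fuses silu + multiply into one kernel:
silu = Nx.multiply(gate, Nx.sigmoid(gate))
Nx.multiply(silu, up)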