# `EMLXAxon`
[🔗](https://github.com/elixir-nx/emlx/blob/v0.3.0/emlx_axon/lib/emlx_axon.ex#L1)

Axon model rewrites that swap supported nodes for fused `EMLX.Fast` Metal kernels.

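Pass an `%Axon{}` model through `rewrite/1` before compiling it with
`Axon.build/2` or `Bumblebee.Text.generation/4`. Supported normalization,
attention, RoPE, SwiGLU, dropout and KV-cache nodes are replaced with
single-kernel MLX equivalents. Nodes that are dead at inference time are
replaced with pass-throughs.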

## Supported rewrites

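| Key                        | Matched node                                   | Replaced with                                  |
|----------------------------|------------------------------------------------|------------------------------------------------|
| `:rms_norm`                | `op_name: :rms_norm`, `shift: 0.0`             | `EMLX.Fast.rms_norm/3`                         |
| `:layer_norm`              | `op_name: :layer_norm`, `channel_index: -1`    | `EMLX.Fast.layer_norm/3,4`                     |
| `:rotary_embedding`        | Bumblebee `apply_rotary_embedding/5`           | `EMLX.Fast.rope_with_positions/6` (`EMLX.Fast.rope_with_freqs/6` for `:llama3` scaling) |
| `:sdpa`                    | Bumblebee `attention_output_impl/3`            | `EMLX.Fast.scaled_dot_product_attention_causal_key_masked/5` (causal) or `EMLX.Fast.scaled_dot_product_attention/4` (unmasked) |
| `:dropout`                 | `op_name: :dropout` (inference)                | identity pass-through                          |
| `:swiglu`                  | `:multiply(container(up, silu(gate)))`         | `EMLX.Fast.swiglu/2`                           |
| `:attn_weights`            | `:bb_attn_weights` (local Bumblebee patch)     | constant-zero layer                            |
| `:if_present`              | Bumblebee KV-cache `if_present` conditionals   | the "cache present" branch                     |
| `:native_attention`        | Bumblebee causal self-attention                | `EMLX.kv_cache_attention_masked/8`             |
| `:nullify_block_cache`     | `put_block_cache`                              | identity pass-through                          |
| `:gqa_cache_fix` (opt-in)  | `update_attention_cache` key/value inputs      | same layer with `repeat_interleave` stripped   |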

## Usage

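    repo = {:hf, "Qwen/Qwen3-0.6B"}
    {:ok, model_info} = Bumblebee.load_model(repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
    {:ok, generation_config} = Bumblebee.load_generation_config(repo)

    # Swap supported nodes for EMLX.Fast kernels before compiling
    model_info = %{model_info | model: EMLXAxon.rewrite(model_info.model)}

    serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config,
      compile: [batch_size: 1, sequence_length: 256],
      defn_options: [compiler: EMLX]
    )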

## Limitations

- **`:rms_norm`** rewrite requires `shift: 0.0`. Nodes with a non-zero shift
  are skipped because `EMLX.Fast.rms_norm(x, w, eps)` computes `x/rms(x)*w`,
  not `x/rms(x)*(shift+w)`.

- **`:rotary_embedding`** rewrite assumes sequential position IDs within each
  batch example (standard causal LM). Non-sequential schemes (packed sequences,
  custom position offsets) will produce incorrect results. Bumblebee's
  `apply_rotary_embedding/5` is matched by function identity via `function_info/1` —
  this is tied to Bumblebee's internal implementation and may break across major
  Bumblebee version changes.

  RoPE scaling strategies: `:llama3` precomputes the inverse-frequency tensor at
  rewrite time and dispatches to `EMLX.Fast.rope_with_freqs/6`. Other strategies
  (`:linear`, `:dynamic`, `:longrope`) fall back to `rope_with_positions` with the
  standard base frequency. Their frequencies cannot be precomputed, because they
  either depend on the sequence length or require a post-multiply of cos/sin that
  `mlx::fast::rope` cannot absorb.

- **`:sdpa`** rewrite threads `key_mask` (from input padding) through to the
  C++ NIF, which checks at runtime whether the mask is all-ones and fast-paths
  to the pure causal Metal kernel when it is. Padded batches get a combined
  causal + key_mask additive mask. Sliding-window attention falls back to the
  original `attention_output_impl`. Inference-only: dropout is elided.

- **`:dropout`** rewrite replaces `op_name: :dropout` nodes with an identity
  pass-through. Dropout is a no-op at inference time regardless of rate, so this
  removes the per-decode-step NIF-boundary crossings without any functional
  change. Not appropriate for training graphs.

- **`:swiglu`** rewrite matches `:multiply` nodes backed by a `:container` node whose
  two parents include one `:silu` node (the Bumblebee SwiGLU pattern:
  `multiply(container(up_proj, silu(gate_proj)))`). Replaces the multiply + container +
  silu triple with a single `EMLX.Fast.swiglu/2` call. Does not match generic
  multiplications or containers without a silu child.

- **`:attn_weights`** rewrite replaces `:bb_attn_weights` passthrough nodes (added by
  the local Bumblebee patch) with a no-arg constant-zero layer, cutting the
  `attention_weights_impl` sub-graph and its K-side `repeat_interleave` nodes out of the
  reachable graph entirely. Inference-only: the attention weights tensor is never used
  for token generation.

- **`:if_present`** rewrite replaces Bumblebee's KV-cache conditional nodes with their
  "cache present" branch. In compiled serving the KV cache is always initialized (never
  `%Axon.None{}`), so the else branch is dead code. Removing the `:if_present` nodes
  and their `:optional` wrappers eliminates per-step Axon dispatch overhead without any
  functional change.

- **`:gqa_cache_fix`** rewrite fixes a shape mismatch that arises when GQA head expansion
  (`repeat_interleave`) runs *before* `update_attention_cache`. The standard Bumblebee
  transformer block expands keys/values from `num_key_value_heads` to `num_attention_heads`
  before the cache update, but the cache is allocated with `num_key_value_heads`. This
  rewrite strips the `repeat_interleave` from the key and value inputs to every
  `update_attention_cache` Axon layer, so the cache receives the compact GQA tensors.
  The SDPA rewriter (`maybe_strip_repeat_interleave`) already handles the expanded-head
  removal on the SDPA side, and MLX fast SDPA handles GQA natively.
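The `:llama3` strategy above can be precomputed because its scaled inverse frequencies depend only on static config values. The sketch below shows that computation in plain Elixir. It assumes the Llama 3 scaling rule as implemented in Hugging Face Transformers, and its default options mirror the Llama 3.1 config values. `Llama3RopeSketch` is a hypothetical module, not EMLXAxon's implementation.

```elixir
defmodule Llama3RopeSketch do
  @moduledoc false
  # Hypothetical helper, not part of EMLXAxon: follows the Llama 3 RoPE
  # scaling rule as implemented in Hugging Face Transformers.

  # Standard RoPE inverse frequencies: 1 / base^(2i / dim) for i in 0..dim/2-1.
  def inv_freq(dim, base) do
    for i <- 0..(div(dim, 2) - 1)//1, do: 1.0 / :math.pow(base, 2 * i / dim)
  end

  # Scales each frequency according to its wavelength band. Only static
  # config values are involved, so this can run once at rewrite time.
  def llama3_inv_freq(dim, base, opts \\ []) do
    factor = Keyword.get(opts, :factor, 8.0)
    low = Keyword.get(opts, :low_freq_factor, 1.0)
    high = Keyword.get(opts, :high_freq_factor, 4.0)
    old_ctx = Keyword.get(opts, :original_max_position_embeddings, 8192)

    low_freq_wavelen = old_ctx / low
    high_freq_wavelen = old_ctx / high

    for freq <- inv_freq(dim, base) do
      wavelen = 2 * :math.pi() / freq

      cond do
        # High-frequency band: left untouched.
        wavelen < high_freq_wavelen ->
          freq

        # Low-frequency band: divided by the scaling factor.
        wavelen > low_freq_wavelen ->
          freq / factor

        # In between: linear blend of the scaled and unscaled frequency.
        true ->
          smooth = (old_ctx / wavelen - low) / (high - low)
          (1 - smooth) * freq / factor + smooth * freq
      end
    end
  end
end
```

Consider `Llama3RopeSketch.llama3_inv_freq(64, 500_000)`. It leaves the first (highest) frequency at `1.0` and divides the last one by `8.0`. With `factor: 1.0`, every frequency is returned unchanged.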

# `attn_weights_rewriter`

```elixir
@spec attn_weights_rewriter() :: (Axon.Node.t() ->
                              :skip | ([Axon.t(), ...], Axon.t() -> Axon.t()))
```

Returns the rewriter function for Bumblebee aux attention-weights nodes.

Matches `:bb_attn_weights` nodes — a passthrough layer inserted by the local
Bumblebee patch around the `{output, weights}` return of `Layers.attention/8`.
Replaces the node with a no-arg constant-zero layer so the entire
`attention_weights_impl` sub-graph (and the K-side `repeat_interleave` it
consumes) becomes unreachable. Inference-only: attention weight tensors are
never used for token generation.

# `dropout_rewriter`

```elixir
@spec dropout_rewriter() :: (Axon.Node.t() ->
                         ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for `dropout` nodes.

At inference time, dropout is always a pass-through regardless of rate. This
rewriter replaces every `:dropout` node with an identity layer, eliminating the
NIF-boundary crossing without any functional change.

**Not appropriate for training graphs** — only enable this rewriter when the
model will be used for inference only.
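The spec above describes the rewriter contract. The rewriter receives a node and returns either `:skip` or a replacement function of `(inputs, output)`. The sketch below shows that shape for dropout. A plain map stands in for `%Axon.Node{}`, and only the `:op_name` field is modeled; this stand-in is hypothetical. Returning the input unchanged is one way to express the identity pass-through. It is not necessarily how EMLXAxon builds its identity layer.

```elixir
defmodule DropoutRewriterSketch do
  # Hypothetical stand-in: a plain map plays the role of %Axon.Node{}.
  def dropout_rewriter do
    fn
      # Dropout node: the replacement ignores the original output and
      # forwards its single input unchanged (identity at inference time).
      %{op_name: :dropout} -> fn [input], _output -> input end
      # Any other node is left untouched.
      _node -> :skip
    end
  end
end
```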

# `function_info`

```elixir
@spec function_info(term()) :: {module(), atom(), non_neg_integer()} | nil
```

Extracts `{module, name, arity}` from a function reference, or returns `nil`
for non-function values.

Works for both named functions (`def`/`defp`/`defn`/`defnp`) and closures.
Closures report the module where they were defined and a generated name like
`"-foo/2-fun-0-"`, which is distinct from any hand-written function name and
therefore safe to use in MFA comparisons.

Note: Nx's `defnp` may compile to a closure rather than a named function,
so this helper intentionally does not filter by `:erlang.fun_info(:type)`.
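The following is a minimal sketch of such a helper built on `:erlang.fun_info/1`. The module name is hypothetical, and the real implementation may differ. The sketch reads `:module`, `:name` and `:arity` for any fun and, as described above, does not filter on `:type`.

```elixir
defmodule FunctionInfoSketch do
  # Hypothetical stand-in for EMLXAxon.function_info/1.
  def function_info(fun) when is_function(fun) do
    info = :erlang.fun_info(fun)
    # Works for external funs (&Mod.fun/arity) and closures alike; closures
    # report their defining module and a compiler-generated name.
    {Keyword.fetch!(info, :module), Keyword.fetch!(info, :name), Keyword.fetch!(info, :arity)}
  end

  def function_info(_other), do: nil

  # Returns a closure, used to show that its generated name differs from
  # any hand-written function name.
  def make_closure, do: fn x -> x + 1 end
end
```

`FunctionInfoSketch.function_info(&Enum.map/2)` returns `{Enum, :map, 2}`. A rewriter can compare tuples like this one against the MFA it expects to match.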

# `gqa_cache_fix_rewriter`

```elixir
@spec gqa_cache_fix_rewriter() :: (Axon.Node.t() ->
                               ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for the GQA key/value cache shape fix.

In the standard Bumblebee transformer block, GQA head expansion via
`repeat_interleave` (expanding from `num_key_value_heads` to `num_attention_heads`)
is applied to the key and value tensors *before* `update_attention_cache`. However,
`init_cache` allocates the preallocated buffer with `num_key_value_heads`, causing
a shape mismatch at Axon compile time when `num_key_value_heads < num_attention_heads`.

This rewriter fixes the graph by stripping the `repeat_interleave` node from the
key and value inputs of every `update_attention_cache` layer. The cache update then
operates on the compact GQA tensors. The SDPA rewriter (`maybe_strip_repeat_interleave`)
separately handles the expanded-head removal on the SDPA path, and MLX fast SDPA
handles GQA natively without explicit head repetition.

Only applies when the key or value parent is a Bumblebee `repeat_interleave` node.
Models without GQA (or where `repeat_interleave` is already absent) are unaffected.
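The mismatch is easiest to see with concrete head counts. `repeat_interleave` along the heads axis repeats each key/value head `num_attention_heads / num_key_value_heads` times in place. The expanded tensor therefore no longer fits a cache allocated with `num_key_value_heads` heads. The plain-Elixir sketch below models heads as list elements. The module, its functions and the 16/8 head counts are illustrative assumptions.

```elixir
defmodule GqaSketch do
  # Repeat each KV head n_rep times in place: [k0, k1] -> [k0, k0, k1, k1].
  def repeat_interleave(heads, n_rep), do: Enum.flat_map(heads, &List.duplicate(&1, n_rep))

  # Index of the KV head that query head `q` reads under GQA.
  def kv_head_for(q, num_attention_heads, num_key_value_heads) do
    div(q, div(num_attention_heads, num_key_value_heads))
  end
end
```

With 8 KV heads and 16 query heads, the expanded list has 16 entries, while a cache sized for 8 heads expects 8. Passing the compact 8-head list to the cache, as the rewrite does, keeps the shapes aligned. Each query head can still be mapped back to its KV head with `kv_head_for/3`.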

# `if_present_rewriter`

```elixir
@spec if_present_rewriter() :: (Axon.Node.t() ->
                            ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for `:if_present` nodes.

Bumblebee wraps every KV-cache operation in `Layers.if_present(cache, ...)` to
handle the case where no cache is provided. In compiled serving the cache is
always initialized (never `%Axon.None{}`), so the conditional is dead code.

The rewriter unconditionally selects the "cache present" branch (`on_true`) and
lets the "no cache" branch (`on_false`) and all its `:optional` wrappers become
unreachable, pruning them from the compiled graph.

**Do not enable for training graphs** — training models typically run without a
KV cache and rely on the `else` branch.

# `layer_norm_rewriter`

```elixir
@spec layer_norm_rewriter() :: (Axon.Node.t() ->
                            ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for `layer_norm` nodes.

Replaces `op_name: :layer_norm` nodes (Axon's built-in layer normalisation)
with an Axon layer that calls `EMLX.Fast.layer_norm/3,4` — a single fused
Metal shader. Skips nodes where `channel_index` is not `-1` (last axis),
as the kernel only normalises over the last axis.
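Normalising over the last axis means each row is centred and scaled independently, then multiplied by the weight and shifted by the bias. For reference, here is a plain-Elixir sketch on nested lists. The module name is hypothetical, and this is not the EMLX kernel.

```elixir
defmodule LayerNormSketch do
  # Normalise every row (the last axis) of a list of rows, then apply
  # the per-feature weight and bias.
  def layer_norm(rows, weight, bias, eps \\ 1.0e-5) do
    Enum.map(rows, fn row ->
      n = length(row)
      mean = Enum.sum(row) / n
      # Population variance over the row
      var = Enum.sum(Enum.map(row, fn x -> (x - mean) * (x - mean) end)) / n
      inv_std = 1.0 / :math.sqrt(var + eps)

      Enum.zip_with([row, weight, bias], fn [x, w, b] -> (x - mean) * inv_std * w + b end)
    end)
  end
end
```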

# `load_quantized`

```elixir
@spec load_quantized({:local, Path.t()}, keyword()) :: {:ok, map()} | {:error, term()}
```

Loads a quantized Bumblebee model from an MLX-4bit checkpoint directory.

Combines three steps into one call:

1. Loads the Axon model structure from `config.json` via `Bumblebee.load_model/2`.
2. Loads the MLX-4bit safetensors weights via `EMLXAxon.MLX4BitParams.load/1`,
   dequantizing and transposing to Bumblebee `{in, out}` layout (BF16).
3. Re-quantizes all eligible weight matrices via `EMLXAxon.QuantizeParams.quantize/1`
   so that `Nx.dot` dispatch routes to `EMLX.quantized_matmul` at serving time.

Returns `{:ok, model_info}` compatible with `Bumblebee.Text.generation/4`.
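Step 2 above dequantizes MLX-4bit weights. The sketch below shows one packed word being dequantized. It assumes MLX's affine layout: eight 4-bit values packed per `uint32`, lowest bits first, with one scale and one bias per group of values (MLX's default group size is 64). `Mlx4BitSketch` is hypothetical and is not the code of `EMLXAxon.MLX4BitParams`.

```elixir
defmodule Mlx4BitSketch do
  import Bitwise

  # Split one 32-bit word into eight 4-bit integers, lowest nibble first
  # (assumed packing order).
  def unpack(word), do: for(i <- 0..7, do: (word >>> (4 * i)) &&& 0xF)

  # Affine dequantization: w = scale * q + bias, with the scale and bias
  # shared by every value in the same group.
  def dequantize(word, scale, bias), do: Enum.map(unpack(word), &(scale * &1 + bias))
end
```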

## Usage

    path = Path.expand("~/models/Qwen3-0.6B-MLX-4bit")
    {:ok, model_info} = EMLXAxon.load_quantized({:local, path})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:local, path})
    {:ok, gen_cfg}   = Bumblebee.load_generation_config({:local, path})
    gen_cfg = Bumblebee.configure(gen_cfg, max_new_tokens: 100)

    serving = Bumblebee.Text.generation(model_info, tokenizer, gen_cfg,
      compile: [batch_size: 1, sequence_length: 256],
      defn_options: [compiler: EMLX]
    )

    result = Nx.Serving.run(serving, "The capital of France is")

## Notes

- **Do not apply `EMLXAxon.rewrite/2` after `load_quantized`.** The rotary
  embedding rewrite is incompatible with the standard Bumblebee `native_kv_cache: false`
  path and produces incorrect outputs. The BF16 fast ops (`rms_norm`, `swiglu`,
  `dropout`, `sdpa`) may be enabled for quantized models once the rotary embedding
  rewrite is fixed.

- Model architecture is inferred from `config.json` in the checkpoint directory.
  Validated with Bumblebee and Qwen3-0.6B.

- Quantization metadata: `EMLXAxon.QuantizeParams` logs shape-mismatch warnings for
  tensors whose physical packed dimensions differ from the shapes the Bumblebee model
  expects. These warnings are benign. The quantized tensors are still used correctly
  through the EMLX backend's `EMLX.quantized_matmul` dispatch.

# `native_attention_rewriter`

```elixir
@spec native_attention_rewriter() :: (Axon.Node.t() ->
                                  ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for Bumblebee causal self-attention nodes.

This rewrite replaces the attention output with a single `Nx.runtime_call`
callback that updates a process-local ETS K/V cache and calls
`EMLX.kv_cache_attention_masked/8`. It intentionally only matches causal
attention without sliding-window masking; cross-attention and local attention
fall back to the original graph.
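The ETS-backed cache described above can be pictured as a per-process table, keyed by layer, that gains one K/V entry per decode step. The sketch below uses plain atoms in place of tensors. The table layout, the `{layer, keys, values}` record and the module name are hypothetical. They do not reflect EMLXAxon's actual storage or the `EMLX.kv_cache_attention_masked/8` call.

```elixir
defmodule KvCacheSketch do
  # A :private table can only be read and written by the owning process.
  def new, do: :ets.new(:kv_cache_sketch, [:set, :private])

  # Append this step's key/value for `layer` and return the full history,
  # which an attention call over the cache would then consume.
  # (List append is O(n) per step; acceptable for an illustration.)
  def append(table, layer, k, v) do
    {keys, values} =
      case :ets.lookup(table, layer) do
        [{^layer, keys, values}] -> {keys, values}
        [] -> {[], []}
      end

    keys = keys ++ [k]
    values = values ++ [v]
    true = :ets.insert(table, {layer, keys, values})
    {keys, values}
  end
end
```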

# `nullify_block_cache_rewriter`

```elixir
@spec nullify_block_cache_rewriter() :: (Axon.Node.t() ->
                                     ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for Bumblebee block-cache update nodes.

When native attention owns K/V state in an ETS table, the Axon block-cache
update chain is dead. Replacing `put_block_cache` with an identity lets dead-code
elimination (DCE) prune `get_block_cache`, `update_attention_cache`, and the
container plumbing.
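The pruning comes down to reachability. After the identity swap, no path leads from the output to the cache-update nodes. The sketch below walks parent edges in a toy graph stored as a map from node id to parent ids. The node names are simplified labels, and the real Axon graph carries more plumbing.

```elixir
defmodule DceSketch do
  # Collect every node id reachable from `output` by walking parent edges.
  def reachable(graph, output), do: walk([output], graph, MapSet.new())

  defp walk([], _graph, seen), do: seen

  defp walk([id | rest], graph, seen) do
    if MapSet.member?(seen, id) do
      walk(rest, graph, seen)
    else
      walk(Map.fetch!(graph, id) ++ rest, graph, MapSet.put(seen, id))
    end
  end
end
```

Before the rewrite, `put_block_cache` depends on `update_attention_cache`. Once it becomes an identity over the hidden state, `update_attention_cache` and `get_block_cache` drop out of the reachable set and can be pruned.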

# `rewrite`

```elixir
@spec rewrite(
  Axon.t(),
  keyword()
) :: Axon.t()
```

Rewrites all supported nodes in `model` to their `EMLX.Fast` equivalents.

## Options

  * `:only` — list of atoms selecting which rewrites to apply. Defaults to
    `[:rms_norm, :layer_norm, :rotary_embedding, :sdpa, :dropout, :swiglu,
    :attn_weights, :if_present, :native_attention, :nullify_block_cache]`.
    Pass `:gqa_cache_fix` explicitly when targeting a Bumblebee build whose
    `init_cache` allocates the KV cache with `num_key_value_heads` rather than
    `num_attention_heads` (i.e. the upstream PR branch patch).

## Example

    model = EMLXAxon.rewrite(model)                                  # all default rewrites
    model = EMLXAxon.rewrite(model, only: [:rms_norm, :layer_norm])  # normalization only

# `rms_norm_rewriter`

```elixir
@spec rms_norm_rewriter() :: (Axon.Node.t() ->
                          ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for `rms_norm` nodes.

Replaces `op_name: :rms_norm` nodes with `shift: 0.0` with an Axon layer
that calls `EMLX.Fast.rms_norm/3` — a single fused Metal shader.
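A reference implementation shows why the rewrite requires `shift: 0.0`. The fused kernel multiplies by `w`, while a shifted node multiplies by `shift + w`. The plain-list sketch below uses a hypothetical module and is not the Metal kernel.

```elixir
defmodule RmsNormSketch do
  # x / sqrt(mean(x^2) + eps) * (shift + w). With shift = 0.0 this is the
  # x / rms(x) * w form the fused kernel computes.
  def rms_norm(x, w, eps, shift \\ 0.0) do
    mean_sq = Enum.sum(Enum.map(x, &(&1 * &1))) / length(x)
    inv_rms = 1.0 / :math.sqrt(mean_sq + eps)
    Enum.zip_with(x, w, fn xi, wi -> xi * inv_rms * (shift + wi) end)
  end
end
```

With unit weights and `shift: 1.0`, every output is twice the unshifted result. Substituting the unshifted kernel for such a node would therefore silently change the model.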

# `rotary_embedding_rewriter`

```elixir
@spec rotary_embedding_rewriter(reference() | nil) :: (Axon.Node.t() ->
                                                   ([Axon.t()], Axon.t() ->
                                                      Axon.t())
                                                   | :skip)
```

Returns the rewriter function for Bumblebee's `rotary_embedding` nodes.

Matches `%Axon.Node{op: &Bumblebee.Layers.apply_rotary_embedding/5}` by MFA
identity via `function_info/1`, then replaces it with an `EMLX.Fast.rope_with_positions/6`
call (or `EMLX.Fast.rope_with_freqs/6` for `:llama3` RoPE scaling) on both Q and K.
The replacement node returns `{q_rotated, k_rotated}`, so downstream
`Axon.nx(_, &elem(&1, i))` unwrap nodes continue to work unchanged.

**Assumes sequential positions** — see `EMLXAxon` moduledoc for the limitation.
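The sketch below applies RoPE to a single head vector in plain Elixir. It assumes the "rotate half" layout, in which element `i` is paired with element `i + d/2`, and a base of `10_000.0`; both are illustrative assumptions. The `positions/2` helper produces the sequential position IDs (`offset, offset + 1, ...`) that the rewrite relies on. Module and function names are hypothetical.

```elixir
defmodule RopeSketch do
  # Sequential position ids the rewrite assumes: offset, offset + 1, ...
  def positions(offset, seq_len), do: Enum.to_list(offset..(offset + seq_len - 1)//1)

  # Rotate one head vector `x` (even length d) at position `pos`, pairing
  # element i with element i + d/2 ("rotate half" layout).
  def rotate(x, pos, base \\ 10_000.0) do
    d = length(x)
    {x1, x2} = Enum.split(x, div(d, 2))

    # Angle for pair i: pos * base^(-2i/d)
    angles = for i <- 0..(div(d, 2) - 1)//1, do: pos * :math.pow(base, -2 * i / d)

    {first, second} =
      [x1, x2, angles]
      |> Enum.zip()
      |> Enum.map(fn {a, b, t} ->
        {a * :math.cos(t) - b * :math.sin(t), a * :math.sin(t) + b * :math.cos(t)}
      end)
      |> Enum.unzip()

    first ++ second
  end
end
```

Position 0 leaves the vector unchanged, and every rotation preserves its norm. A packed sequence that restarts positions mid-batch would not match `positions/2`, which is the case the limitation warns about.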

# `sdpa_rewriter`

```elixir
@spec sdpa_rewriter() :: (Axon.Node.t() -> ([Axon.t()], Axon.t() -> Axon.t()) | :skip)
```

Returns the rewriter function for Bumblebee's attention output nodes.

Matches `%Axon.Node{op: &Bumblebee.Layers.attention_output_impl/3}` by MFA
identity via `function_info/1`, then navigates up through the dropout and
`attention_weights_impl` nodes to recover Q and K, and replaces the whole
attention chain with a single `EMLX.Fast` SDPA call.

- **`causal: true`, no window_size** — uses `scaled_dot_product_attention_causal_key_masked/5`.
  The `key_mask` is threaded through; the C++ NIF checks if it is all-ones at
  runtime and dispatches to the pure causal Metal kernel (no mask allocation) or
  builds a combined additive mask for padded batches.
- **`causal: false`, no window_size** — uses `scaled_dot_product_attention/4`
  (unmasked, for cross-attention or prefix LM heads).
- **`window_size` set** — re-applies the original `attention_output_impl` unchanged.

**Inference-only**: attention dropout is elided (a no-op at inference time). Nodes
with `dropout_rate > 0` are skipped to preserve training-time stochastic behaviour.
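The causal branch above combines the causal mask with the padding `key_mask`. The plain-Elixir sketch below shows that combined additive mask along with the all-ones check that enables the fast path. It assumes `1` marks a real token in `key_mask` and that the queries are the last `q_len` positions. It uses `-1.0e9` in place of negative infinity, because Erlang floats have no infinity. The module name is hypothetical, and this is not the C++ NIF.

```elixir
defmodule CausalMaskSketch do
  # Stand-in for -inf: Erlang floats cannot represent infinity.
  @masked -1.0e9

  # Fast-path check: no padding means the pure causal kernel suffices.
  def all_ones?(key_mask), do: Enum.all?(key_mask, &(&1 == 1))

  # Additive mask of shape {q_len, k_len}: 0.0 where query i may attend
  # to key j (not in the future and not padding), @masked elsewhere.
  def additive_mask(q_len, key_mask) do
    offset = length(key_mask) - q_len

    for i <- 0..(q_len - 1)//1 do
      key_mask
      |> Enum.with_index()
      |> Enum.map(fn {keep, j} ->
        if j <= i + offset and keep == 1, do: 0.0, else: @masked
      end)
    end
  end
end
```

Take a left-padded prompt with `key_mask = [0, 1, 1]`. The padding column is masked in every row, and each remaining row sees only keys up to its own position. During a single-token decode step (`q_len = 1`) with no padding, every key is visible.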

# `swiglu_rewriter`

```elixir
@spec swiglu_rewriter() :: (Axon.Node.t() -&gt;
                        ([Axon.t()], Axon.t() -&gt; Axon.t()) | :skip)
```

Returns the rewriter function for SwiGLU nodes.

Matches `:multiply` nodes backed by a single `:container` parent whose two
inputs include one `:silu` node (the Bumblebee SwiGLU pattern:
`multiply(container(up_proj, silu(gate_proj)))`). Replaces the
multiply + container + silu triple with a single `EMLX.Fast.swiglu/2` call,
passing the gate's raw input (pre-silu) and the up-projection directly to the
fused NIF.

Generic `:multiply` nodes (no `:container` parent, or container without a
`:silu` child) are reconstructed identically.
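For reference, the fused operation computes `silu(gate) * up` element-wise. The plain-list sketch below is hypothetical and is not the NIF. Its `(gate, up)` argument order is an assumption, not the documented signature of `EMLX.Fast.swiglu/2`.

```elixir
defmodule SwigluSketch do
  # silu(x) = x * sigmoid(x) = x / (1 + e^-x)
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Element-wise silu(gate) * up: the multiply + container + silu triple
  # that the rewrite collapses into one call.
  def swiglu(gate, up), do: Enum.zip_with(gate, up, fn g, u -> silu(g) * u end)
end
```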

---

*Consult [api-reference.md](api-reference.md) for the complete listing.*
