EMLXAxon.Qwen3.Model (emlx_axon v0.3.0)


Qwen3 quantized model state struct and forward pass.

Defn / JIT strategy

Every hot-path computation that consists purely of tensor arithmetic is wrapped in a defnp kernel (declared in Layers and Attention). defnp uses Nx.Defn.Compiler.__jit__ — the same mechanism as Nx.Defn.jit/1 — to compile the function once per unique input shape and cache the result. Subsequent calls skip Nx-side type/shape inference and dispatch directly to the compiled kernel.
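A minimal sketch of what such a kernel can look like, assuming Nx is available. The module name KernelSketch, the RMSNorm and SiLU-gated MLP formulas as written, and the 1.0e-6 epsilon are illustrative assumptions, not the actual contents of Layers or Attention. The public def wrappers are there because defnp functions are private to their module.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KernelSketch do
  import Nx.Defn

  # Public wrappers: defnp kernels are private, so callers go through these.
  def rms_norm(x, weight), do: rms_norm_kernel(x, weight)
  def swiglu(gate, up), do: swiglu_kernel(gate, up)

  # RMSNorm over the last axis: x / sqrt(mean(x^2) + eps) * weight.
  # eps = 1.0e-6 is an illustrative value, not read from any model config.
  defnp rms_norm_kernel(x, weight) do
    variance = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + 1.0e-6) * weight
  end

  # SiLU-gated MLP activation: silu(gate) * up, where silu(g) = g * sigmoid(g).
  defnp swiglu_kernel(gate, up) do
    gate * Nx.sigmoid(gate) * up
  end
end
```

Both kernels are pure tensor arithmetic with no captured backend tensors and no shape that depends on runtime values, which is the property the strategy above relies on.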

Operations that still run eagerly:

  • The quantized Nx.dot projections. Nx.Defn.Evaluator (EMLX's default compiler) cannot mix Nx.Defn.Expr nodes coming from jit arguments with EMLX.Backend tensors captured in closures. Wrapping these individual ops with Nx.Defn.jit would require a backend that ships its own Nx.Defn.Compiler implementation (e.g. EXLA).
  • The KV-cache write via Nx.put_slice (dynamic start index) and the read of the valid prefix of the cache (dynamic end index).
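The two eager cache operations can be sketched as follows. The {batch, kv_heads, max_len, head_dim} layout and the module name KVCacheSketch are assumptions for illustration; the model's actual cache layout is not documented here. The read uses Nx.slice_along_axis, whose length must be a plain integer, so a read whose length grows with current_len produces a new output shape on every step and does not fit a kernel compiled once per shape.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KVCacheSketch do
  # Assumed layout for one layer's cache: {batch, kv_heads, max_len, head_dim}.

  # Write `new` ({1, kv_heads, seq_len, head_dim}) starting at position current_len.
  def write(cache, new, current_len) do
    {_, _, max_len, _} = Nx.shape(cache)
    {_, _, seq_len, _} = Nx.shape(new)

    # Nx clamps put_slice start indices so the slice fits, which would silently
    # overwrite earlier positions; fail loudly instead.
    if current_len + seq_len > max_len do
      raise ArgumentError, "KV cache overflow: #{current_len + seq_len} > #{max_len}"
    end

    Nx.put_slice(cache, [0, 0, current_len, 0], new)
  end

  # Read the first valid_len positions along the sequence axis (axis 2).
  def read_valid(cache, valid_len) do
    Nx.slice_along_axis(cache, 0, valid_len, axis: 2)
  end
end
```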

GPU sync: EMLX.eval is called once per token, at the sampler boundary, so the lazy MLX graph spans all 28 layers before any CPU sync.

Summary

Functions

forward(input_ids, kv_cache, current_len, state)

Full forward pass for a single decode step.

init_kv_cache(state, max_len)

Initialise a preallocated KV cache for all layers.

Functions

forward(input_ids, kv_cache, current_len, state)

Full forward pass for a single decode step.

  • input_ids: {1, seq_len} integer tensor (the prompt, or the latest token)
  • kv_cache: list of {k_cache, v_cache} preallocated tensors
  • current_len: number of tokens already written into the KV cache
  • state: the model state (EMLXAxon.Qwen3.Model.State.t())

Returns {logits, kv_cache_updated}, where logits has shape {1, vocab_size}. The cache is updated with Nx.put_slice, which returns new tensors rather than mutating its argument, so kv_cache_updated (not the list that was passed in) must be used for the next call.
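The calling convention (the prompt first with current_len 0, then one token per step with current_len advanced by the number of tokens already cached) can be sketched as a greedy loop. DecodeSketch and the forward_fun parameter are hypothetical: forward_fun stands in for forward/4 with state already bound, so the sketch runs against a stub instead of real weights.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DecodeSketch do
  # forward_fun is a hypothetical stand-in for forward/4 with state bound:
  #   (input_ids {1, seq_len}, kv_cache, current_len) -> {logits {1, vocab}, kv_cache}
  def greedy(forward_fun, kv_cache, prompt_ids, max_new_tokens)
      when is_list(prompt_ids) and prompt_ids != [] and max_new_tokens >= 1 do
    # Prefill: the whole prompt goes in at once; nothing is cached yet.
    {logits, cache} = forward_fun.(Nx.tensor([prompt_ids]), kv_cache, 0)
    first = argmax_token(logits)
    step(forward_fun, cache, length(prompt_ids), first, max_new_tokens - 1, [first])
  end

  defp step(_fun, _cache, _len, _token, 0, acc), do: Enum.reverse(acc)

  defp step(fun, cache, len, token, remaining, acc) do
    # Decode: feed only the latest token; `len` tokens are already cached.
    # Always continue with the cache returned by the previous call.
    {logits, cache} = fun.(Nx.tensor([[token]]), cache, len)
    next = argmax_token(logits)
    step(fun, cache, len + 1, next, remaining - 1, [next | acc])
  end

  # Greedy sampling: index of the largest logit in the {1, vocab} row.
  defp argmax_token(logits) do
    logits |> Nx.argmax(axis: 1) |> Nx.to_flat_list() |> hd()
  end
end
```

With the real model, forward_fun could be `fn ids, cache, len -> EMLXAxon.Qwen3.Model.forward(ids, cache, len, state) end`, with the initial cache from `init_kv_cache(state, max_len)`.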

init_kv_cache(state, max_len)

@spec init_kv_cache(EMLXAxon.Qwen3.Model.State.t(), pos_integer()) :: [
  {Nx.Tensor.t(), Nx.Tensor.t()}
]

Initialise a preallocated KV cache for all layers.

Returns a list of {k_cache, v_cache} pairs, one per transformer layer, where each cache is pre-allocated to max_len positions.
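A sketch of the preallocation, assuming each layer's cache has shape {1, num_kv_heads, max_len, head_dim} and starts zero-filled. InitCacheSketch, its explicit dimension parameters, and the default :f32 type are hypothetical; the real init_kv_cache/2 presumably derives the dimensions from state.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule InitCacheSketch do
  # Hypothetical stand-in for init_kv_cache/2: model dimensions are passed
  # explicitly instead of being read from an EMLXAxon.Qwen3.Model.State.
  def init(num_layers, num_kv_heads, head_dim, max_len, type \\ :f32)
      when num_layers > 0 and max_len > 0 do
    # One zero-filled tensor; Nx tensors are immutable, so every layer can
    # start from the same value and later writes produce new tensors.
    zeros = Nx.broadcast(Nx.tensor(0, type: type), {1, num_kv_heads, max_len, head_dim})
    List.duplicate({zeros, zeros}, num_layers)
  end
end
```

Sharing a single zero tensor across layers is safe only because Nx.put_slice returns a new tensor; with a mutable buffer, each layer would need its own allocation.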