Qwen3 quantized model state struct and forward pass.
Defn / JIT strategy
Every hot-path computation that consists purely of tensor arithmetic is
wrapped in a defnp kernel (declared in Layers and Attention).
defnp uses Nx.Defn.Compiler.__jit__ — the same mechanism as
Nx.Defn.jit/1 — to compile the function once per unique input shape and
cache the result. Subsequent calls skip Nx-side type/shape inference and
dispatch directly to the compiled kernel.
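For illustration, a minimal sketch of such a kernel; the rms_norm function and module name are hypothetical and not taken from this codebase (the real kernels are declared as defnp, i.e. private; defn is used here so the sketch compiles standalone):

```elixir
defmodule LayersSketch do
  import Nx.Defn

  # Pure tensor arithmetic: compiled once per unique input shape, after which
  # the cached kernel is reused and Nx-side shape/type inference is skipped.
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    variance = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + opts[:eps]) * weight
  end
end
```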
Functions that still run eagerly:
- The quantized Nx.dot projections — the Nx.Defn.Evaluator (EMLX's default
  compiler) cannot mix jit-argument Nx.Defn.Expr nodes with captured
  EMLX.Backend tensors in closures. Wrapping individual ops with Nx.Defn.jit
  would require a backend that implements Nx.Defn.Compiler (e.g. EXLA).
- The Nx.put_slice KV-cache update (dynamic start index) and the valid-slice
  read (dynamic end index); a sketch of both follows this list.
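A rough sketch of those two eager operations, assuming a cache layout of {batch, num_kv_heads, max_len, head_dim} (the actual layout used by this module is not shown here):

```elixir
defmodule KVCacheSketch do
  # Eager write: the start position comes from current_len, a runtime value,
  # so Nx.put_slice runs outside any defnp kernel.
  def write(k_cache, k_new, current_len) do
    Nx.put_slice(k_cache, [0, 0, current_len, 0], k_new)
  end

  # Eager read of the valid prefix: the slice length is also a runtime value.
  def read_valid(k_cache, valid_len) do
    {batch, heads, _max_len, head_dim} = Nx.shape(k_cache)
    Nx.slice(k_cache, [0, 0, 0, 0], [batch, heads, valid_len, head_dim])
  end
end
```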
GPU sync
EMLX.eval is called once per token, at the sampler boundary, so the full lazy
MLX graph spans all 28 layers before any CPU sync.
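A minimal sketch of that boundary; sample_token/1 is a hypothetical sampler and the exact arity of EMLX.eval is assumed here:

```elixir
# One decode step: build the lazy graph across all layers, force it once, sample.
{logits, kv_cache} = EMLXAxon.Qwen3.Model.forward(next_id, kv_cache, current_len, state)
EMLX.eval(logits)                        # single CPU/GPU sync for this token
next_id = logits |> sample_token() |> Nx.reshape({1, 1})
```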
Summary
Functions
Full forward pass for a single decode step.
Initialise a preallocated KV cache for all layers.
Functions
@spec forward(
        Nx.Tensor.t(),
        [{Nx.Tensor.t(), Nx.Tensor.t()}],
        non_neg_integer(),
        EMLXAxon.Qwen3.Model.State.t()
      ) :: {Nx.Tensor.t(), [{Nx.Tensor.t(), Nx.Tensor.t()}]}
Full forward pass for a single decode step.
- input_ids — {1, seq_len} integer tensor (the prompt or latest token)
- kv_cache — list of {k_cache, v_cache} preallocated tensors
- current_len — number of tokens already written into the KV cache
Returns {logits, kv_cache_updated} where logits has shape {1, vocab_size}.
The cache is updated in-place via Nx.put_slice; kv_cache_updated is the
same list with updated slices.
@spec init_kv_cache(EMLXAxon.Qwen3.Model.State.t(), pos_integer()) ::
        [{Nx.Tensor.t(), Nx.Tensor.t()}]
Initialise a preallocated KV cache for all layers.
Returns a list of {k_cache, v_cache} pairs, one per transformer layer,
where each cache is pre-allocated to max_len positions.
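A short usage sketch tying the two functions together; the max_len value, prompt_ids, and the already-loaded state are assumptions:

```elixir
alias EMLXAxon.Qwen3.Model

kv_cache = Model.init_kv_cache(state, 4096)      # one {k_cache, v_cache} per layer

# Prefill: write the whole prompt into the cache starting at position 0.
{logits, kv_cache} = Model.forward(prompt_ids, kv_cache, 0, state)
Nx.shape(logits)                                 #=> {1, vocab_size}
```

Subsequent decode steps pass current_len equal to the number of tokens already written, as in the per-token sketch under GPU sync above.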