# `EMLXAxon.Qwen3.Model`
[🔗](https://github.com/elixir-nx/emlx/blob/v0.3.0/emlx_axon/lib/emlx_axon/qwen3/model.ex#L1)

Qwen3 quantized model state struct and forward pass.

## Defn / JIT strategy

Every hot-path computation that consists purely of tensor arithmetic is
wrapped in a `defnp` kernel (declared in `Layers` and `Attention`).
`defnp` uses `Nx.Defn.Compiler.__jit__`, the same mechanism as
`Nx.Defn.jit/1`, to compile each function once per unique combination of
input shapes and types and cache the result. Subsequent calls skip Nx-side
type/shape inference and dispatch directly to the compiled kernel.
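As a minimal sketch of this pattern, the module below wraps a pure tensor computation (an RMS norm, chosen only for illustration) in a `defnp` kernel behind a public `def`. `HypotheticalLayers` and its functions are hypothetical, not the actual `Layers` code; the sketch assumes only the `nx` package.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule HypotheticalLayers do
  # Hypothetical module illustrating the defnp-kernel pattern; not the
  # real EMLXAxon.Qwen3.Layers implementation.
  import Nx.Defn

  # Public entry point used by eager code. The private defn kernel below is
  # traced into an Nx.Defn expression and handed to the active compiler.
  def rms_norm(x, weight), do: rms_norm_kernel(x, weight)

  # Pure tensor arithmetic only: no captured backend tensors, no dynamic shapes.
  defnp rms_norm_kernel(x, weight) do
    variance = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + 1.0e-6) * weight
  end
end
```

Keeping the kernel private and free of closures over backend tensors is what lets it be traced purely from its arguments.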

Operations that still run eagerly:
- The quantized `Nx.dot` projections. `Nx.Defn.Evaluator` (EMLX's default
  compiler) cannot mix `Nx.Defn.Expr` nodes that come from jit arguments with
  `EMLX.Backend` tensors captured in closures. Wrapping the individual ops
  in `Nx.Defn.jit` would require a backend that implements
  `Nx.Defn.Compiler` (e.g. EXLA).
- The `Nx.put_slice` KV-cache update (dynamic start index) and the read of
  the valid cache prefix (dynamic end index).
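The two eager cache operations can be sketched with plain Nx calls. This assumes a hypothetical cache layout of `{batch, n_kv_heads, max_len, head_dim}` with the sequence on axis 2; the layout actually used by `EMLXAxon` may differ, and `KVCacheSketch` is a hypothetical name. The write's start index changes on every step, and the read returns a tensor whose shape depends on the number of valid positions, so its length has to be an ordinary integer rather than a traced value inside a fixed-shape kernel.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KVCacheSketch do
  # Assumed cache layout: {batch, n_kv_heads, max_len, head_dim}, with the
  # sequence on axis 2. The real EMLXAxon layout may differ.
  @seq_axis 2

  # Write `new_kv` ({batch, n_kv_heads, seq_len, head_dim}) starting at
  # sequence position `current_len`. Returns a new tensor; the input is unchanged.
  def write(cache, new_kv, current_len) do
    Nx.put_slice(cache, [0, 0, current_len, 0], new_kv)
  end

  # Read the filled prefix [0, valid_len) along the sequence axis. The
  # result's shape depends on `valid_len`, so it must be a plain integer.
  def valid(cache, valid_len) do
    Nx.slice_along_axis(cache, 0, valid_len, axis: @seq_axis)
  end
end
```

For example, writing a 3-token slice at `current_len = 2` into an 8-position cache and then reading `valid(cache, 5)` returns exactly the five filled positions.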

GPU sync: `EMLX.eval` is called once per token at the sampler boundary so
the full lazy MLX graph spans all 28 layers before any CPU sync.

# `forward`

```elixir
@spec forward(
  Nx.Tensor.t(),
  [{Nx.Tensor.t(), Nx.Tensor.t()}],
  non_neg_integer(),
  EMLXAxon.Qwen3.Model.State.t()
) :: {Nx.Tensor.t(), [{Nx.Tensor.t(), Nx.Tensor.t()}]}
```

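Full forward pass for one step: either the prompt prefill or a single decode
step.

- `input_ids` — `{1, seq_len}` integer tensor (the prompt or the latest token)
- `kv_cache`  — list of `{k_cache, v_cache}` preallocated tensors
- `current_len` — number of tokens already written into the KV cache; the
  keys and values for `input_ids` are written starting at this position
- `state` — the `EMLXAxon.Qwen3.Model.State` holding the quantized model

Returns `{logits, kv_cache_updated}` where `logits` has shape `{1, vocab_size}`.
The cache is updated with `Nx.put_slice`, which keeps each preallocated shape
fixed. Because Nx tensors are immutable, `kv_cache_updated` holds the updated
tensors rather than the original ones, and it is the list to pass to the next
call.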
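Based on the contract above, a generation loop would thread the returned cache into the next call and advance `current_len` by the number of tokens just written: the prompt is passed once at `current_len = 0`, then each sampled token is fed back on its own. The sketch below assumes greedy (argmax) sampling and takes the forward step as a function so it can run without a model; with a loaded state you would pass something like `&EMLXAxon.Qwen3.Model.forward(&1, &2, &3, state)`. `DecodeSketch` is hypothetical, and the per-token `EMLX.eval` sync is not shown.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule DecodeSketch do
  # Hypothetical greedy decode loop. `forward_fn` is assumed to follow the
  # forward/4 contract with the model state already bound:
  #   forward_fn.(input_ids, kv_cache, current_len) -> {logits, kv_cache}
  # where `logits` has shape {1, vocab_size}.
  def generate(forward_fn, prompt_ids, kv_cache, max_new_tokens)
      when max_new_tokens >= 1 do
    prompt = Nx.tensor([prompt_ids])
    {1, prompt_len} = Nx.shape(prompt)

    # Prefill: the whole prompt is written starting at position 0.
    {logits, kv_cache} = forward_fn.(prompt, kv_cache, 0)
    first = greedy(logits)

    {tokens, kv_cache, _len} =
      Enum.reduce(2..max_new_tokens//1, {[first], kv_cache, prompt_len}, fn _,
                                                                            {[last | _] = acc, cache, len} ->
        # Decode: feed only the newest token, written at position `len`.
        {logits, cache} = forward_fn.(Nx.tensor([[last]]), cache, len)
        {[greedy(logits) | acc], cache, len + 1}
      end)

    {Enum.reverse(tokens), kv_cache}
  end

  # Greedy sampling: index of the largest logit, as a plain integer.
  defp greedy(logits) do
    logits |> Nx.argmax(axis: 1) |> Nx.squeeze() |> Nx.to_number()
  end
end
```

Returning the final cache alongside the tokens lets a caller continue generating from where the loop stopped.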

# `init_kv_cache`

```elixir
@spec init_kv_cache(EMLXAxon.Qwen3.Model.State.t(), pos_integer()) :: [
  {Nx.Tensor.t(), Nx.Tensor.t()}
]
```

Initialise a preallocated KV cache for all layers.

Returns a list of `{k_cache, v_cache}` pairs, one per transformer layer,
where each cache is preallocated to `max_len` positions (the second
argument), the maximum number of tokens the cache can hold.
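A minimal sketch of what this preallocation amounts to: one zero-filled `{k, v}` pair per layer, sized for `max_len` positions. The shape `{1, num_kv_heads, max_len, head_dim}`, the `{:f, 32}` default type, the hypothetical `KVInitSketch` name, and passing the dimensions explicitly instead of reading them from `EMLXAxon.Qwen3.Model.State` are all assumptions for illustration.

```elixir
Mix.install([{:nx, "~> 0.9"}])

defmodule KVInitSketch do
  # Hypothetical stand-in for init_kv_cache/2: dimensions are passed
  # explicitly instead of being read from the model state, and the layout
  # {batch, num_kv_heads, max_len, head_dim} is an assumption.
  def init(num_layers, num_kv_heads, head_dim, max_len, type \\ {:f, 32}) do
    shape = {1, num_kv_heads, max_len, head_dim}

    for _layer <- 1..num_layers//1 do
      zeros = Nx.broadcast(Nx.tensor(0, type: type), shape)
      # Sharing one tensor for K and V is safe: Nx tensors are immutable.
      {zeros, zeros}
    end
  end
end
```

Allocating the full `max_len` up front keeps every cache tensor the same shape for the whole generation, so only the write position changes between steps.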

