EMLXAxon.Qwen3.Attention (emlx_axon v0.3.0)


Grouped-query attention (GQA) for Qwen3, with a preallocated KV cache.

KV cache layout: {1, num_kv_heads, max_len, head_dim} (heads before sequence). This layout lets us write k (already in {B, N, T, D} after transposing) into the cache without an extra transpose, and it feeds directly into EMLX.Fast.scaled_dot_product_attention, which expects {B, N, T, D}.
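A minimal plain-Nx sketch of that cache write, assuming k and v have already been projected, split into heads, and transposed to {B, N, T, D}; the module and function names here are illustrative, not part of this library:

```elixir
defmodule CacheWriteSketch do
  import Nx.Defn

  # k_cache/v_cache: {1, num_kv_heads, max_len, head_dim}
  # k/v:             {1, num_kv_heads, seq_len, head_dim}
  # Because the sequence axis is third in both layouts, a single put_slice
  # at offset `current_len` appends the new positions -- no extra transpose.
  defn write_kv(k_cache, v_cache, k, v, current_len) do
    k_cache = Nx.put_slice(k_cache, [0, 0, current_len, 0], k)
    v_cache = Nx.put_slice(v_cache, [0, 0, current_len, 0], v)
    {k_cache, v_cache}
  end
end
```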

RoPE is computed by EMLX.Fast.rope/6 (single Metal shader, no precomputed cos/sin needed). The offset is the current cache fill length.
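For reference, a plain-Nx version of the same rotation (rotate-half formulation, offset added to each position index) might look like the sketch below. The fused kernel does this work in a single shader; the module name, default theta, and option handling here are assumptions for illustration only:

```elixir
defmodule RopeSketch do
  import Nx.Defn

  # x: {1, num_heads, seq_len, head_dim}; offset is the current cache fill.
  defn apply_rope(x, offset, opts \\ []) do
    opts = keyword!(opts, theta: 1_000_000.0)
    {_b, _h, seq_len, head_dim} = Nx.shape(x)
    half = div(head_dim, 2)

    # Per-dimension inverse frequencies and absolute positions (offset + t).
    inv_freq = 1.0 / Nx.pow(opts[:theta], Nx.iota({half}) / half)
    pos = Nx.iota({seq_len}) + offset
    angles = Nx.new_axis(pos, 1) * Nx.new_axis(inv_freq, 0)
    cos = Nx.cos(angles) |> Nx.reshape({1, 1, seq_len, half})
    sin = Nx.sin(angles) |> Nx.reshape({1, 1, seq_len, half})

    # Rotate-half: split the head dimension and apply the 2D rotation.
    x1 = Nx.slice_along_axis(x, 0, half, axis: -1)
    x2 = Nx.slice_along_axis(x, half, half, axis: -1)

    Nx.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis: -1)
  end
end
```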

Summary

Functions

forward(hidden, k_cache, v_cache, current_len, q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, cfg)

GQA forward.

Inputs:

  • hidden — {1, seq_len, hidden_size} (post-norm)
  • k_cache — {1, num_kv_heads, max_len, head_dim}, preallocated
  • v_cache — {1, num_kv_heads, max_len, head_dim}, preallocated
  • current_len — number of valid positions already in the cache
  • q_proj, k_proj, v_proj, o_proj — quantized weight tensors
  • q_norm, k_norm — per-head RMSNorm weights (Qwen3 variant)
  • cfg — model config map

Returns {attn_out, k_cache_updated, v_cache_updated}.
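
A hypothetical call site, sketching how a decode loop might preallocate the caches and thread them through the function; the `params` and `cfg` key names are assumptions for illustration, not part of this module's API:

```elixir
alias EMLXAxon.Qwen3.Attention

max_len = 4096
cfg = %{num_kv_heads: 8, head_dim: 128}  # assumed keys, for illustration

# Preallocate both caches in the {1, num_kv_heads, max_len, head_dim} layout.
k_cache = Nx.broadcast(0.0, {1, cfg.num_kv_heads, max_len, cfg.head_dim})
v_cache = Nx.broadcast(0.0, {1, cfg.num_kv_heads, max_len, cfg.head_dim})

# `hidden`, `current_len`, and `params` come from the surrounding decode loop.
{attn_out, k_cache, v_cache} =
  Attention.forward(
    hidden, k_cache, v_cache, current_len,
    params.q_proj, params.k_proj, params.v_proj, params.o_proj,
    params.q_norm, params.k_norm,
    cfg
  )

# On the next step, pass the updated caches back in and advance current_len
# by the number of tokens just written.
```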