Grouped-query attention (GQA) for Qwen3, with a preallocated KV cache.
KV cache layout: {1, num_kv_heads, max_len, head_dim} (heads before sequence).
This layout lets us write k (already in {B, N, T, D} after transposing) into
the cache without an extra transpose, and the cache feeds directly into
EMLX.Fast.scaled_dot_product_attention, which expects {B, N, T, D}.
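For illustration, a minimal sketch of the cache update under this layout, assuming Nx tensors; write_kv_cache/3 and valid_kv/2 are hypothetical helper names, not part of the module's documented API.

```elixir
# Because k/v arrive as {1, num_kv_heads, seq_len, head_dim} (heads before
# sequence), appending new positions is a single put_slice along axis 2.
defp write_kv_cache(cache, kv, current_len) do
  Nx.put_slice(cache, [0, 0, current_len, 0], kv)
end

# Only the filled prefix takes part in attention; the slice stays
# {B, N, T, D}, which is the layout the fast SDPA kernel expects.
defp valid_kv(cache, total_len) do
  {1, num_kv_heads, _max_len, head_dim} = Nx.shape(cache)
  Nx.slice(cache, [0, 0, 0, 0], [1, num_kv_heads, total_len, head_dim])
end
```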
RoPE is computed by EMLX.Fast.rope/6 (single Metal shader, no precomputed
cos/sin needed). The offset is the current cache fill length.
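Because the offset is just current_len, positions appended to the cache continue the rotation where the previous tokens left off. A hedged sketch follows; the argument order shown for EMLX.Fast.rope/6 (dims, traditional, base, scale, offset) mirrors MLX's fast rope and is an assumption, as are the cfg key names.

```elixir
# Sketch only: EMLX.Fast.rope/6 argument order is assumed here, not taken
# from the library docs. The offset is the current cache fill length.
defp apply_rope(x, current_len, cfg) do
  EMLX.Fast.rope(x, cfg.head_dim, false, cfg.rope_theta, 1.0, current_len)
end
```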
Functions
GQA forward.
Inputs:
  * hidden - {1, seq_len, hidden_size} (post-norm)
  * k_cache - {1, num_kv_heads, max_len, head_dim}, preallocated
  * v_cache - {1, num_kv_heads, max_len, head_dim}, preallocated
  * current_len - number of valid positions already in the cache
  * q_proj, k_proj, v_proj, o_proj - quantized weight tensors
  * q_norm, k_norm - per-head RMSNorm weights (Qwen3 variant)
  * cfg - model config map
Returns {attn_out, k_cache_updated, v_cache_updated}.
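Putting the pieces together, a hedged sketch of the forward pass under the assumptions above. quantized_dense/2, rms_norm/3, and the cfg keys are hypothetical names, the SDPA options are assumptions rather than the library's documented signature, and write_kv_cache/3, valid_kv/2, and apply_rope/3 are the helpers sketched earlier.

```elixir
def forward(hidden, k_cache, v_cache, current_len,
            q_proj, k_proj, v_proj, o_proj, q_norm, k_norm, cfg) do
  %{num_heads: n_q, num_kv_heads: n_kv, head_dim: d} = cfg
  {1, seq_len, _hidden_size} = Nx.shape(hidden)

  # {1, T, N*D} -> {1, N, T, D}: heads before sequence, matching the cache.
  to_heads = fn x, n ->
    x |> Nx.reshape({1, seq_len, n, d}) |> Nx.transpose(axes: [0, 2, 1, 3])
  end

  q = to_heads.(quantized_dense(hidden, q_proj), n_q)
  k = to_heads.(quantized_dense(hidden, k_proj), n_kv)
  v = to_heads.(quantized_dense(hidden, v_proj), n_kv)

  # Qwen3 variant: per-head RMSNorm on q and k (over head_dim), before RoPE.
  q = rms_norm(q, q_norm, cfg.rms_norm_eps)
  k = rms_norm(k, k_norm, cfg.rms_norm_eps)

  # RoPE at the current cache offset (see the RoPE sketch above).
  q = apply_rope(q, current_len, cfg)
  k = apply_rope(k, current_len, cfg)

  # Append new positions; k/v are already {B, N, T, D}, so no extra transpose.
  k_cache = write_kv_cache(k_cache, k, current_len)
  v_cache = write_kv_cache(v_cache, v, current_len)

  total_len = current_len + seq_len
  k_valid = valid_kv(k_cache, total_len)
  v_valid = valid_kv(v_cache, total_len)

  # GQA: n_kv heads are shared across n_q query heads by the fast SDPA kernel.
  # Scale/mask options are assumptions; causal masking for prefill is omitted.
  scale = 1.0 / :math.sqrt(d)
  attn = EMLX.Fast.scaled_dot_product_attention(q, k_valid, v_valid, scale: scale)

  # {1, N, T, D} -> {1, T, N*D}, then the output projection.
  attn
  |> Nx.transpose(axes: [0, 2, 1, 3])
  |> Nx.reshape({1, seq_len, n_q * d})
  |> quantized_dense(o_proj)
  |> then(&{&1, k_cache, v_cache})
end
```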