EMLXAxon.Qwen3.Layers (emlx_axon v0.3.0)


Stateless layer primitives: RMSNorm, SwiGLU.

rms_norm/3 delegates to EMLX.Fast.rms_norm, which runs as a single fused Metal shader. RoPE is no longer computed in this module: EMLX.Fast.rope/6 is called directly in Attention.forward/10, after projecting and transposing to {B, N, T, D}.

Functions

rms_norm(x, weight, eps)

Root-mean-square layer normalisation via mlx::fast::rms_norm.

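x: tensor of any shape; normalisation is applied over the last axis. weight: scale vector of shape {hidden}, where hidden is the size of the last axis of x. eps: small constant added to the mean square before the reciprocal square root, for numerical stability.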
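To make the computation concrete, here is a minimal reference sketch of RMSNorm in plain Nx. It is not the fused kernel this function delegates to, and it says nothing about the precision or performance of that kernel. The module name RMSNormSketch and the example tensors are hypothetical, and the script assumes the nx package can be installed with Mix.install.

```elixir
# Reference sketch only: plain Nx, not the fused EMLX/Metal kernel.
Mix.install([{:nx, "~> 0.9"}])

defmodule RMSNormSketch do
  # Hypothetical helper that mirrors the documented argument order
  # rms_norm(x, weight, eps).
  def rms_norm(x, weight, eps) do
    # Mean of squares over the last axis; keep_axes keeps the reduced
    # axis so the result broadcasts against x.
    mean_sq = x |> Nx.pow(2) |> Nx.mean(axes: [-1], keep_axes: true)

    x
    |> Nx.multiply(Nx.rsqrt(Nx.add(mean_sq, eps)))
    |> Nx.multiply(weight)
  end
end

# Two rows, hidden size 2; each row is normalised independently.
x = Nx.tensor([[3.0, 4.0], [1.0, 1.0]])
weight = Nx.tensor([1.0, 1.0])
IO.inspect(RMSNormSketch.rms_norm(x, weight, 1.0e-6))
```

Unlike LayerNorm, this does not subtract the mean: each row is only divided by its root mean square and then multiplied element-wise by weight.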

swiglu(gate, up)

SwiGLU activation: silu(gate) * up. gate and up must have the same shape.
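The formula above can be sketched in plain Nx as follows, using silu(x) = x * sigmoid(x). This is an illustrative reference, not necessarily how swiglu/2 is implemented here. The module name SwiGLUSketch and the example tensors are hypothetical, and the script assumes the nx package can be installed with Mix.install.

```elixir
# Reference sketch only: plain Nx, not the EMLX implementation.
Mix.install([{:nx, "~> 0.9"}])

defmodule SwiGLUSketch do
  # Hypothetical helper mirroring the documented swiglu(gate, up) signature.
  def swiglu(gate, up) do
    # silu(gate) = gate * sigmoid(gate), then an element-wise product with up.
    gate
    |> Nx.multiply(Nx.sigmoid(gate))
    |> Nx.multiply(up)
  end
end

# gate and up share the same shape, as the documentation requires.
gate = Nx.tensor([0.0, 1.0, -1.0])
up = Nx.tensor([2.0, 2.0, 2.0])
IO.inspect(SwiGLUSketch.swiglu(gate, up))
```

A gate value of 0 zeroes the corresponding output regardless of up, because silu(0) = 0.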