Root Mean Square Layer Normalization.
A simpler, faster alternative to standard LayerNorm: it normalizes by the RMS of the activations without centering (no mean subtraction). Used by LLaMA, Mistral, Mamba-2, and many modern transformer variants.
Formula
RMSNorm(x) = x / sqrt(mean(x^2) + eps) * gamma

Compared to LayerNorm, which computes both a mean and a variance, RMSNorm computes a single statistic (the root mean square), skipping the centering step and saving roughly half of the normalization arithmetic.
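The formula maps directly onto Nx. Below is a minimal sketch using `Nx.Defn`; the module and function names and the `eps` default are illustrative, not necessarily this module's actual API:

```elixir
defmodule RMSNormSketch do
  import Nx.Defn

  # Illustrative implementation of the formula above (not the module's API).
  # `eps` guards against division by zero when the activations are all zero.
  defn rms_norm(x, gamma, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    # mean(x^2) over the feature (last) axis, kept for broadcasting
    ms = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x / Nx.sqrt(ms + opts[:eps]) * gamma
  end
end
```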
Usage
```elixir
# As an Axon layer
normalized = RMSNorm.layer(input, hidden_size: 256)
```
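For context, here is a hypothetical placement inside a larger Axon model; the input name, sizes, and surrounding layers are illustrative:

```elixir
# Hypothetical model sketch; sizes and layer choices are illustrative.
model =
  Axon.input("features", shape: {nil, 256})
  |> RMSNorm.layer(hidden_size: 256)
  |> Axon.dense(1024, activation: :gelu)
  |> Axon.dense(256)
```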
- "Root Mean Square Layer Normalization" (Zhang & Sennrich, 2019)
- https://arxiv.org/abs/1910.07467
Functions
Compute RMSNorm on a raw tensor.
Parameters
- `x` - Input tensor `[..., hidden_size]`
- `gamma` - Learnable scale `[hidden_size]`
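For shape intuition, a direct call might look like the following; `forward/2` is a stand-in name, since the function's actual signature isn't shown above:

```elixir
# Hypothetical call; `RMSNorm.forward/2` is a stand-in for the real name.
x = Nx.iota({2, 4}, type: :f32)  # [batch = 2, hidden_size = 4]
gamma = Nx.broadcast(1.0, {4})   # learnable scale, typically initialized to ones
out = RMSNorm.forward(x, gamma)  # same shape as x: {2, 4}
```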
Build an RMSNorm Axon layer.
Options
- `:hidden_size` - Feature dimension for the learnable scale (required)
- `:epsilon` - Numerical stability constant (default: `1.0e-6`)
- `:name` - Layer name prefix (default: `"rms_norm"`)
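As a rough sketch, a layer builder like this can be assembled from Axon's custom-layer primitives (`Axon.param/3` and `Axon.layer/3`); the module's actual implementation may differ, and the module name here is hypothetical:

```elixir
defmodule MyRMSNorm do
  import Nx.Defn

  # Hypothetical layer builder; the real module's internals may differ.
  def layer(%Axon{} = input, opts \\ []) do
    hidden_size = Keyword.fetch!(opts, :hidden_size)
    epsilon = Keyword.get(opts, :epsilon, 1.0e-6)
    name = Keyword.get(opts, :name, "rms_norm")

    # Learnable scale, one value per feature, initialized to ones.
    gamma =
      Axon.param("gamma", fn _input_shape -> {hidden_size} end,
        initializer: :ones
      )

    # Extra options (here :epsilon) are forwarded to the layer function.
    Axon.layer(&rms_norm_impl/3, [input, gamma],
      name: name,
      op_name: :rms_norm,
      epsilon: epsilon
    )
  end

  defnp rms_norm_impl(x, gamma, opts \\ []) do
    ms = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x / Nx.sqrt(ms + opts[:epsilon]) * gamma
  end
end
```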