Scalable-Softmax (SSMax): sequence-length-aware softmax.
SSMax adjusts softmax temperature based on sequence length to maintain consistent attention sharpness across different context sizes:
SSMax(x)_i = exp(s*log(n) * x_i) / sum_j(exp(s*log(n) * x_j))
Where:
- n is the sequence length
- s is a learnable scalar (default initialization: 1.0)
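As an illustration of this formula (not the module's actual implementation), a minimal Nx sketch might look like the following; the SSMaxSketch module name and the max-subtraction step for numerical stability are assumptions added here:

defmodule SSMaxSketch do
  import Nx.Defn

  # logits has the sequence dimension last; s is a scalar; n is the sequence length.
  defn ssmax(logits, s, n) do
    scaled = logits * s * Nx.log(n)
    # Subtract the row max for numerical stability; this does not change the result.
    scaled = scaled - Nx.reduce_max(scaled, axes: [-1], keep_axes: true)
    exps = Nx.exp(scaled)
    exps / Nx.sum(exps, axes: [-1], keep_axes: true)
  end
end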
Key Innovation
Standard softmax becomes increasingly uniform as sequence length grows (more tokens to distribute attention over). SSMax compensates by scaling the logits by s*log(n), so the scaling grows with the context while the learned s controls its strength (see the numeric sketch below):
- Larger s: sharper attention, with the effect growing for longer sequences
- s = 1/log(n): equivalent to standard softmax at that sequence length
- s approaching 0: attention flattens toward uniform
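A quick numeric illustration of this dilution and the SSMax correction, reusing the SSMaxSketch.ssmax/3 helper sketched above (all values illustrative):

n = 4096
logits = Nx.concatenate([Nx.tensor([1.0]), Nx.broadcast(0.0, {n - 1})])

# Standard softmax: the winning token's weight collapses toward 1/n (about 0.00066 here).
exps = Nx.exp(logits)
standard_top = exps |> Nx.divide(Nx.sum(exps)) |> Nx.reduce_max()

# SSMax with s = 1.0: logits are scaled by log(4096) ~ 8.3, so the winning token
# keeps roughly half of the attention mass (about 0.5 here).
ssmax_top = logits |> SSMaxSketch.ssmax(1.0, n) |> Nx.reduce_max()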
Usage as Drop-in Softmax Replacement
# In attention computation
scores = Nx.dot(q, Nx.transpose(k))
scores = Nx.divide(scores, scale)
attn_weights = SSMax.compute(scores, s_param, seq_len)
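For context, a more self-contained version of this pattern (the shapes, random inputs, and 1/sqrt(head_dim) scale are illustrative assumptions; SSMax.compute/3 is documented below):

# Illustrative shapes: batch of 2, 4 heads, sequence length 60, head dimension 64.
{batch, heads, seq_len, head_dim} = {2, 4, 60, 64}
key = Nx.Random.key(42)
{q, key} = Nx.Random.normal(key, shape: {batch, heads, seq_len, head_dim})
{k, _key} = Nx.Random.normal(key, shape: {batch, heads, seq_len, head_dim})

# Scaled dot-product scores: [batch, heads, seq_len, seq_len].
scores = Nx.dot(q, [3], [0, 1], k, [3], [0, 1])
scores = Nx.divide(scores, :math.sqrt(head_dim))

# SSMax replaces the usual softmax over the last axis.
s_param = 1.0
attn_weights = SSMax.compute(scores, s_param, seq_len)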
Usage in Axon Model
model = SSMax.build(embed_dim: 256, hidden_size: 256)
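A sketch of initializing and running the built model; the input layout [batch, window_size, embed_dim] and passing a bare tensor to the single-input model are assumptions, while the [batch, hidden_size] output follows from the build/1 docs below:

model = SSMax.build(embed_dim: 256, hidden_size: 256)

# Compile the model into init/predict functions (standard Axon workflow).
{init_fn, predict_fn} = Axon.build(model)

# Assumed input layout: [batch, window_size, embed_dim] with the default window_size of 60.
template = Nx.template({1, 60, 256}, :f32)
params = init_fn.(template, %{})

output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 256}))
# output has shape {1, 256}, i.e. [batch, hidden_size]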
- "Scalable-Softmax Is All You Need" (2025)
Summary
Functions
Build a transformer model using SSMax instead of standard softmax.
Build attention layer using SSMax instead of softmax.
Apply SSMax to logits tensor.
Create an SSMax Axon layer that learns the scaling parameter.
Get the output dimension for a model configuration.
Types
@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a transformer model using SSMax instead of standard softmax.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :num_layers - Number of transformer blocks (default: 6)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
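For example, with illustrative values for each option:

model =
  SSMax.build(
    embed_dim: 128,
    hidden_size: 256,
    num_heads: 8,
    num_layers: 4,
    dropout: 0.1,
    window_size: 120
  )
# The model outputs [batch, 256] taken from the last position.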
Build attention layer using SSMax instead of softmax.
@spec compute(Nx.Tensor.t(), Nx.Tensor.t() | float(), pos_integer()) :: Nx.Tensor.t()
Apply SSMax to logits tensor.
Parameters
- logits - Attention scores [batch, ..., seq_len]
- s - Learnable scaling parameter (scalar)
- seq_len - Sequence length (integer)
Returns
Normalized attention weights with same shape as logits.
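A small illustrative call, assuming the formula above; note how the single large logit keeps most of the attention mass:

seq_len = 8
logits = Nx.tensor([[2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])
weights = SSMax.compute(logits, 1.0, seq_len)
# weights has shape {1, 8}, each row sums to 1, and with s = 1.0 the first
# position gets a noticeably larger weight than a standard softmax would give it.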
Create an SSMax Axon layer that learns the scaling parameter.
Options
- :name - Layer name prefix (default: "ssmax")
- :init_s - Initial value for s parameter (default: 1.0)
Returns
An Axon layer that applies SSMax to input logits.
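A hypothetical wiring sketch; the constructor name SSMax.layer/2 and that it takes an Axon node carrying attention logits are assumptions, only the :name and :init_s options come from the docs above:

# Hypothetical: attach the SSMax layer to a node carrying attention logits.
scores = Axon.input("scores", shape: {nil, 60})
attn_weights = SSMax.layer(scores, name: "decoder_ssmax", init_s: 1.0)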
@spec output_size(keyword()) :: non_neg_integer()
Get the output dimension for a model configuration.
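For example, assuming the output dimension equals the configured hidden size (consistent with build/1 returning [batch, hidden_size]):

SSMax.output_size(embed_dim: 256, hidden_size: 512)
# => 512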