Edifice.Blocks.SSMax (Edifice v0.2.0)

Scalable-Softmax (SSMax): sequence-length-aware softmax.

SSMax adjusts the softmax temperature based on sequence length so that attention stays comparably sharp across different context sizes:

SSMax(x)_i = exp(s * log(n) * x_i) / sum_j(exp(s * log(n) * x_j))

Where:

  • n is the sequence length
  • s is a learnable scalar (default initialization: 1.0)
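
As a minimal sketch of the formula in plain Nx (not the library's implementation; a max-subtraction is added for numerical stability and does not change the result):

logits = Nx.tensor([2.0, 1.0, 0.5, 0.1])
s = 1.0
n = Nx.size(logits)
scaled = Nx.multiply(logits, s * :math.log(n))
stable = Nx.subtract(scaled, Nx.reduce_max(scaled))
weights = Nx.divide(Nx.exp(stable), Nx.sum(Nx.exp(stable)))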

Key Innovation

Standard softmax becomes increasingly uniform as sequence length grows (there are more tokens to spread attention over). SSMax learns to compensate through the multiplier s * log(n):

  • s * log(n) > 1: sharper attention than standard softmax
  • s * log(n) = 1: exactly standard softmax
  • 0 < s * log(n) < 1: softer attention than standard softmax

Because log(n) increases with sequence length, a fixed s > 0 sharpens attention more at longer contexts, counteracting the flattening.
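
For intuition, the multiplier at the default s = 1.0 grows with context length:

s = 1.0
s * :math.log(60)    #=> ~4.09
s * :math.log(1024)  #=> ~6.93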

Usage as Drop-in Softmax Replacement

alias Edifice.Blocks.SSMax

# In attention computation
scores = Nx.dot(q, Nx.transpose(k))           # raw attention scores
scores = Nx.divide(scores, scale)             # scale is typically sqrt(head_dim)
attn_weights = SSMax.compute(scores, s_param, seq_len)

Usage in Axon Model

model = SSMax.build(embed_dim: 256, hidden_size: 256)
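
A fuller sketch of initializing and running the built model (the {batch, seq_len, embed_dim} input layout here is an assumption; the [batch, hidden_size] output shape is documented under build/1 below):

model = SSMax.build(embed_dim: 256, hidden_size: 256)
{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})
output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 256}))
# output shape: {1, 256}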

Reference

  • "Scalable-Softmax Is All You Need" (2025)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a transformer model using SSMax instead of standard softmax.

build_ssmax_attention(input, opts)
Build an attention layer using SSMax instead of softmax.

compute(logits, s, seq_len)
Apply SSMax to a logits tensor.

layer(input, opts \\ [])
Create an SSMax Axon layer that learns the scaling parameter.

output_size(opts \\ [])
Get the output dimension for a model configuration.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a transformer model using SSMax instead of standard softmax.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_layers - Number of transformer blocks (default: 6)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs a [batch, hidden_size] tensor taken from the last sequence position.
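
For example, a call with every option spelled out (all values except :embed_dim are the documented defaults):

model =
  SSMax.build(
    embed_dim: 128,
    hidden_size: 256,
    num_heads: 4,
    num_layers: 6,
    dropout: 0.1,
    window_size: 60
  )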

build_ssmax_attention(input, opts)

@spec build_ssmax_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build an attention layer using SSMax instead of softmax.
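
A sketch of composing it with an Axon input node (the option keys shown are assumptions carried over from build/1; this function's own options are not documented here):

input = Axon.input("embedding", shape: {nil, 60, 256})
attn = SSMax.build_ssmax_attention(input, hidden_size: 256, num_heads: 4)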

compute(logits, s, seq_len)

@spec compute(Nx.Tensor.t(), Nx.Tensor.t() | float(), pos_integer()) :: Nx.Tensor.t()

Apply SSMax to a logits tensor.

Parameters

  • logits - Attention scores [batch, ..., seq_len]
  • s - Learnable scaling parameter (scalar)
  • seq_len - Sequence length (integer)

Returns

Normalized attention weights with same shape as logits.
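
A quick sanity check that the weights normalize along the trailing (seq_len) axis:

logits = Nx.tensor([[2.0, 1.0, 0.5, 0.1]])
weights = SSMax.compute(logits, 1.0, 4)
Nx.sum(weights, axes: [-1])
#=> approximately 1.0 per row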

layer(input, opts \\ [])

@spec layer(
  Axon.t(),
  keyword()
) :: Axon.t()

Create an SSMax Axon layer that learns the scaling parameter.

Options

  • :name - Layer name prefix (default: "ssmax")
  • :init_s - Initial value for s parameter (default: 1.0)

Returns

An Axon layer that applies SSMax to input logits.
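
A sketch of attaching the layer to precomputed attention scores (the input shape is an assumption; both options use their documented defaults):

scores = Axon.input("scores", shape: {nil, 60})
weights = SSMax.layer(scores, name: "ssmax", init_s: 1.0)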

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output dimension for a model configuration.
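
Presumably this returns the trailing dimension of build/1's [batch, hidden_size] output for the same options; a hedged example:

SSMax.output_size(hidden_size: 512)
#=> 512 (assumption: mirrors the :hidden_size passed to build/1)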