Edifice.Attention.KDA (Edifice v0.2.0)


KDA: Kimi Delta Attention.

Implements the KDA mechanism from "Kimi Linear: An Expressive, Efficient Attention Architecture" (Moonshot AI, 2025). KDA extends Gated DeltaNet with channel-wise (per-dimension) decay gating, allowing different semantic dimensions to persist at different rates.

Key Innovation: Channel-Wise Decay

Where Gated DeltaNet uses a scalar alpha per head (all channels decay at the same rate), KDA uses a vector alpha in R^{d_k} — each channel gets its own independent decay rate.

Gated DeltaNet:  S_t = alpha_t * (I - beta_t * k_t k_t^T) * S_{t-1}
                       + beta_t * k_t v_t^T
                 alpha_t is a scalar per head

KDA:             S_t = (I - beta_t * k_t k_t^T) * Diag(alpha_t) * S_{t-1}
                       + beta_t * k_t v_t^T
                 alpha_t is a vector per channel (d_k dims)

This means syntax cues can persist while recency signals decay rapidly, or vice versa — the model learns per-channel memory dynamics.
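The difference between the two recurrences can be made concrete with a minimal NumPy sketch of a single KDA step. This is illustrative only, not Edifice's implementation; the single-head setup, dimensions, and parameter values are assumptions:

```python
import numpy as np

def kda_step(S, k, v, alpha, beta):
    """One KDA recurrence step for a single head.

    S:     [d_k, d_v] state matrix
    k:     [d_k] key (assumed L2-normalized, as in the architecture above)
    v:     [d_v] value
    alpha: [d_k] channel-wise decay in (0, 1) -- one rate per channel
    beta:  scalar write strength in (0, 1)
    """
    d_k = k.shape[0]
    # S_t = (I - beta * k k^T) @ Diag(alpha) @ S_{t-1} + beta * k v^T
    transition = np.eye(d_k) - beta * np.outer(k, k)
    return transition @ (alpha[:, None] * S) + beta * np.outer(k, v)

rng = np.random.default_rng(0)
d_k, d_v = 8, 8
S = np.zeros((d_k, d_v))
k = rng.normal(size=d_k)
k /= np.linalg.norm(k)            # L2-normalize the key
v = rng.normal(size=d_v)
alpha = rng.uniform(0.5, 1.0, size=d_k)  # each channel decays at its own rate
S = kda_step(S, k, v, alpha, beta=0.9)
```

Setting `alpha` to a single repeated scalar recovers the Gated DeltaNet update; letting its entries differ is exactly the per-channel memory dynamic described above.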

Architecture

Input [batch, seq_len, hidden_size]
      |
      v
[Pre-LayerNorm]
      |
+-----+------+--------+--------+--------+
|     |      |        |        |        |
Q     K      V     Alpha    Beta     Gate
|     |      |    (channel) (scalar) (output)
|     |      |        |        |        |
[Short Conv + SiLU]   |        |        |
|     |      |        |        |        |
[L2 Normalize Q, K]   |        |        |
|     |      |        |        |        |
+-----+------+--------+--------+        |
      |                                 |
[KDA Recurrence]                        |
      |                                 |
[RMSNorm * sigmoid(Gate)] <-------------+
      |
[Output Projection]
      |
[Residual Connection]
      |
      v
Output [batch, seq_len, hidden_size]

Alpha Gate Production

The channel-wise decay is produced by a low-rank MLP: alpha_t = sigmoid(W_up * SiLU(W_down * x_t))

The decay values are stored in log-space for numerical stability.
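A minimal NumPy sketch of this gate follows. The weight shapes, the low-rank size, and the initialization scale are hypothetical, chosen only to illustrate the formula; Edifice's actual parameterization may differ:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def alpha_gate(x, W_down, W_up):
    """Low-rank alpha gate: alpha_t = sigmoid(W_up @ SiLU(W_down @ x_t))."""
    return sigmoid(W_up @ silu(W_down @ x))

rng = np.random.default_rng(0)
hidden, rank, d_k = 32, 4, 16          # hypothetical sizes
W_down = 0.1 * rng.normal(size=(rank, hidden))  # down-projection to low rank
W_up = 0.1 * rng.normal(size=(d_k, rank))       # up-projection to d_k channels
x = rng.normal(size=hidden)

alpha = alpha_gate(x, W_down, W_up)    # one decay rate per channel, in (0, 1)
log_alpha = np.log(alpha)              # log-space form used for stability
```

Because the sigmoid keeps every entry in (0, 1), the log-space values are finite and can be summed across timesteps instead of multiplying many small decays.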

Kimi Linear Hybrid

In the full Kimi Linear model, KDA layers are interleaved with MLA (Multi-head Latent Attention) at a 3:1 ratio.
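One way to picture the 3:1 interleave is as a layer schedule. The helper below is hypothetical, not part of Edifice or Kimi Linear; it simply makes every fourth layer MLA:

```python
def kimi_linear_schedule(num_layers, ratio=3):
    """Hypothetical layer schedule: every (ratio + 1)-th layer is MLA,
    the rest are KDA, giving a ratio:1 KDA-to-MLA interleave."""
    return ["MLA" if (i + 1) % (ratio + 1) == 0 else "KDA"
            for i in range(num_layers)]

schedule = kimi_linear_schedule(8)
# 8 layers at a 3:1 ratio -> three KDA layers before each MLA layer
```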

Usage

model = KDA.build(
  embed_dim: 256,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 4
)

References

Summary

Types

Options for build/1.

Functions

Build a KDA model for sequence processing.

Build a single KDA block for use as a hybrid backbone layer.

Get the output size of a KDA model.

Types

build_opt()

@type build_opt() ::
  {:conv_size, pos_integer()}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:use_short_conv, boolean()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a KDA model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of independent KDA heads (default: 4)
  • :num_layers - Number of KDA layers (default: 4)
  • :dropout - Dropout rate between layers (default: 0.1)
  • :use_short_conv - Use short convolution before Q/K/V (default: true)
  • :conv_size - Short convolution kernel size (default: 4)
  • :window_size / :seq_len - Expected sequence length (default: 60)

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_block(input, opts \\ [])

@spec build_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single KDA block for use as a hybrid backbone layer.

Takes [batch, seq_len, hidden_size] and returns the same shape.

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a KDA model.