Edifice.Attention.MLA (Edifice v0.2.0)


Multi-Head Latent Attention (MLA) from DeepSeek-V2/V3.

MLA compresses key-value representations into low-rank latent vectors, dramatically reducing KV cache memory while maintaining attention quality. It also uses decoupled Rotary Position Embedding (RoPE) to keep position information separate from compressed content.

Key Innovations

  • KV compression: Instead of caching full K,V per head, compress to a low-rank latent c_KV and reconstruct K,V on-the-fly during attention
  • Q compression: Query is also compressed through a low-rank bottleneck
  • Decoupled RoPE: Position information is encoded via separate RoPE dimensions that are concatenated with content dimensions, not mixed
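
To make the data flow concrete, here is a minimal Nx sketch of the KV compression path for a single token, using the default dimensions from build/1. The weight matrices are constant stand-ins for the learned projections, and all variable names are illustrative:

h_dim = 256   # hidden_size
c_dim = 64    # kv_latent_dim
n_heads = 4   # num_heads
d_head = 64   # head_dim

# Stand-in projection weights (learned parameters in the real layer).
w_dkv = Nx.broadcast(0.01, {h_dim, c_dim})
w_uk = Nx.broadcast(0.01, {c_dim, n_heads * d_head})
w_uv = Nx.broadcast(0.01, {c_dim, n_heads * d_head})

h = Nx.broadcast(1.0, {1, h_dim})   # one token's hidden state
c_kv = Nx.dot(h, w_dkv)             # {1, 64}: this latent (plus a small shared RoPE key) is all the cache stores
k = Nx.dot(c_kv, w_uk)              # {1, 256}: content keys, rebuilt on-the-fly
v = Nx.dot(c_kv, w_uv)              # {1, 256}: values, rebuilt on-the-fly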

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------+
| MLA Block x N            |
|  LayerNorm               |
|  MLA Attention:          |
|   h -> W_DKV -> c_KV     |  (KV latent)
|   c_KV -> W_UK -> K_c    |  (content keys)
|   c_KV -> W_UV -> V      |  (values)
|   h -> W_DQ -> c_Q       |  (Q latent)
|   c_Q -> W_UQ -> Q_c     |  (content queries)
|   c_Q -> W_QR -> RoPE    |  (query rope)
|   h -> W_KR -> RoPE      |  (key rope, shared)
|   Q = [Q_c ; Q_r]        |
|   K = [K_c ; K_r]        |
|   score=softmax(QK^T/s)  |
|  Residual                |
|  LayerNorm -> FFN        |
|  Residual                |
+--------------------------+
      |
      v
[batch, hidden_size]       (last timestep)
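
Within each head, the content and RoPE parts are concatenated before the dot product, so the attention operates over head_dim + rope_dim dimensions. A hedged sketch of the score computation for one head under the defaults (head_dim 64, rope_dim 32); the RoPE rotation itself is elided, the tensors are dummies, and the actual layer implementation may differ:

seq = 60
d_c = 64   # head_dim (content)
d_r = 32   # rope_dim (decoupled position)

q_c = Nx.broadcast(0.1, {seq, d_c})
q_r = Nx.broadcast(0.1, {seq, d_r})
k_c = Nx.broadcast(0.1, {seq, d_c})
k_r = Nx.broadcast(0.1, {seq, d_r})
v = Nx.broadcast(0.1, {seq, d_c})

q = Nx.concatenate([q_c, q_r], axis: 1)   # {60, 96}
k = Nx.concatenate([k_c, k_r], axis: 1)   # {60, 96}

scale = :math.sqrt(d_c + d_r)
scores = q |> Nx.dot([1], k, [1]) |> Nx.divide(scale)   # {60, 60}
weights = Axon.Activations.softmax(scores, axis: 1)
out = Nx.dot(weights, v)                                # {60, 64}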

Usage

alias Edifice.Attention.MLA

model =
  MLA.build(
    embed_dim: 287,
    hidden_size: 256,
    num_heads: 4,
    kv_latent_dim: 64,
    num_layers: 4
  )
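
The result is an ordinary Axon graph, so the standard compile-and-run workflow applies. A minimal sketch; the {1, 60, 287} template assumes the default seq_len of 60 and the embed_dim above:

{init_fn, predict_fn} = Axon.build(model)

# Initialize parameters from an input template.
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

# Forward pass on a dummy batch.
output = predict_fn.(params, Nx.broadcast(0.0, {1, 60, 287}))
# output shape: {1, 256}  (hidden_size, last position)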

References

  • DeepSeek-AI. DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model. arXiv:2405.04434.
  • DeepSeek-AI. DeepSeek-V3 Technical Report. arXiv:2412.19437.

Summary

Types

build_opt() - Options for build/1.

Functions

build(opts \\ []) - Build an MLA model for sequence processing.

build_mla_block(input, opts) - Build a single MLA transformer block.

output_size(opts \\ []) - Get the output size of an MLA model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:head_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:kv_latent_dim, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:q_latent_dim, pos_integer()}
  | {:rope_dim, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build an MLA model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :head_dim - Dimension per head for content (default: 64)
  • :kv_latent_dim - Compressed KV latent dimension (default: hidden_size / 4)
  • :q_latent_dim - Compressed Q latent dimension (default: hidden_size * 3 / 4)
  • :rope_dim - Decoupled RoPE dimension per head (default: 32)
  • :num_layers - Number of MLA blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :seq_len - Expected sequence length (default: 60)
  • :window_size - Alias for seq_len (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.
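
A back-of-the-envelope view of the cache saving under these defaults, assuming the cache holds only c_KV plus the shared RoPE key per token and layer, as in the architecture above:

mla_per_token = 64 + 32        # kv_latent_dim + rope_dim = 96 values
mha_per_token = 2 * 4 * 64     # 2 * num_heads * head_dim = 512 values
mha_per_token / mla_per_token  #=> ~5.33x smaller KV cache per layer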

build_mla_block(input, opts)

@spec build_mla_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single MLA transformer block.
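
A sketch of composing the block into a custom Axon graph. The option names mirror build/1, but exactly which options build_mla_block/2 accepts is an assumption here, as is the "frames" input name:

input = Axon.input("frames", shape: {nil, 60, 287})

block =
  input
  |> Axon.dense(256)   # project to hidden_size before the block
  |> Edifice.Attention.MLA.build_mla_block(hidden_size: 256, num_heads: 4)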

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of an MLA model.
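
Handy when wiring the model into a larger pipeline. Given that build/1 outputs [batch, hidden_size], this presumably echoes the configured :hidden_size (an assumption based on the Returns note above):

Edifice.Attention.MLA.output_size(hidden_size: 256)
#=> 256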