Multi-Head Latent Attention (MLA) from DeepSeek-V2/V3.
MLA compresses key-value representations into low-rank latent vectors, dramatically reducing KV cache memory while maintaining attention quality. It also uses decoupled Rotary Position Embedding (RoPE) to keep position information separate from compressed content.
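For a sense of scale using this module's defaults (num_heads: 4, head_dim: 64, kv_latent_dim: 64, rope_dim: 32): standard multi-head attention caches 2 * 4 * 64 = 512 values per token per layer (K and V for every head), while MLA caches only the shared latent plus the shared RoPE key, 64 + 32 = 96 values, roughly a 5x reduction.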
Key Innovations
- KV compression: Instead of caching full K,V per head, compress to a low-rank latent c_KV and reconstruct K,V on the fly during attention (see the sketch after this list)
- Q compression: The query is also compressed through a low-rank bottleneck
- Decoupled RoPE: Position information is encoded via separate RoPE dimensions that are concatenated with content dimensions, not mixed
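To make the projection chain concrete, here is a minimal single-head sketch of the content path in plain Nx. It is not this module's actual implementation: all weight names (w_dkv, w_uk, w_uv, w_dq, w_uq) are illustrative, shapes follow the documented defaults (hidden_size: 256, kv_latent_dim: 64, q_latent_dim: 192, head_dim: 64), and the decoupled RoPE path is omitted.

key = Nx.Random.key(0)
{h, key}     = Nx.Random.normal(key, shape: {60, 256})    # input, [seq, hidden]
{w_dkv, key} = Nx.Random.normal(key, shape: {256, 64})    # KV down-projection
{w_uk, key}  = Nx.Random.normal(key, shape: {64, 64})     # K up-projection
{w_uv, key}  = Nx.Random.normal(key, shape: {64, 64})     # V up-projection
{w_dq, key}  = Nx.Random.normal(key, shape: {256, 192})   # Q down-projection
{w_uq, _key} = Nx.Random.normal(key, shape: {192, 64})    # Q up-projection

c_kv = Nx.dot(h, w_dkv)               # [60, 64] -- the only KV state cached
k_c  = Nx.dot(c_kv, w_uk)             # [60, 64] content keys, rebuilt on the fly
v    = Nx.dot(c_kv, w_uv)             # [60, 64] values, rebuilt on the fly
q_c  = Nx.dot(Nx.dot(h, w_dq), w_uq)  # [60, 64] queries via the Q bottleneck

scores  = Nx.dot(q_c, Nx.transpose(k_c)) |> Nx.divide(Nx.sqrt(64))
weights = Nx.exp(Nx.subtract(scores, Nx.reduce_max(scores, axes: [-1], keep_axes: true)))
weights = Nx.divide(weights, Nx.sum(weights, axes: [-1], keep_axes: true))
out     = Nx.dot(weights, v)          # [60, 64] per-head attention output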
Architecture
Input [batch, seq_len, embed_dim]
|
v
+--------------------------+
| MLA Block x N |
| LayerNorm |
| MLA Attention: |
| h -> W_DKV -> c_KV | (KV latent)
| c_KV -> W_UK -> K_c | (content keys)
| c_KV -> W_UV -> V | (values)
| h -> W_DQ -> c_Q | (Q latent)
| c_Q -> W_UQ -> Q_c | (content queries)
| c_Q -> W_QR -> RoPE | (query rope)
| h -> W_KR -> RoPE | (key rope, shared)
| Q = [Q_c ; Q_r] |
| K = [K_c ; K_r] |
| score = softmax(QK^T/sqrt(d)) | (d = head_dim + rope_dim)
| Residual |
| LayerNorm -> FFN |
| Residual |
+--------------------------+
|
v
[batch, hidden_size] (last timestep)

Usage
model = MLA.build(
embed_dim: 287,
hidden_size: 256,
num_heads: 4,
kv_latent_dim: 64,
num_layers: 4
)
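To sanity-check the output shape, the model can be built and run with the standard Axon API (the batch size and zero-valued input below are arbitrary; this assumes the model exposes a single input):

{init_fn, predict_fn} = Axon.build(model)
template = Nx.template({8, 60, 287}, :f32)   # [batch, seq_len, embed_dim]
params = init_fn.(template, %{})
output = predict_fn.(params, Nx.broadcast(0.0, {8, 60, 287}))
Nx.shape(output)   #=> {8, 256}

References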
- "DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model" (DeepSeek-AI, 2024)
- arXiv: https://arxiv.org/abs/2405.04434
Summary

Functions
- Build an MLA model for sequence processing.
- Build a single MLA transformer block.
- Get the output size of an MLA model.
- Get recommended defaults.

Types

@type build_opt() ::
        {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:head_dim, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:kv_latent_dim, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:q_latent_dim, pos_integer()}
        | {:rope_dim, pos_integer()}
        | {:seq_len, pos_integer()}
        | {:window_size, pos_integer()}

Options for build/1.
Functions
Build an MLA model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :head_dim - Dimension per head for content (default: 64)
- :kv_latent_dim - Compressed KV latent dimension (default: hidden_size / 4)
- :q_latent_dim - Compressed Q latent dimension (default: hidden_size * 3 / 4)
- :rope_dim - Decoupled RoPE dimension per head (default: 32)
- :num_layers - Number of MLA blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :seq_len - Expected sequence length (default: 60)
- :window_size - Alias for seq_len (default: 60)
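A hedged example of overriding the latent sizes explicitly (the values shown simply restate the documented defaults for hidden_size: 256):

model = MLA.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  kv_latent_dim: 64,   # hidden_size / 4
  q_latent_dim: 192,   # hidden_size * 3 / 4
  rope_dim: 32,
  num_layers: 4
)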
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Build a single MLA transformer block.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of an MLA model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.
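A sketch of how these helpers might compose, assuming recommended_defaults/0 returns a keyword list compatible with build/1:

opts  = MLA.recommended_defaults() |> Keyword.put(:embed_dim, 287)
model = MLA.build(opts)
MLA.output_size(opts)   #=> 256 (matches :hidden_size)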