viva_tensor/flash_attention
Flash Attention - O(n) Memory Attention
REVOLUCIONÁRIO: Reduz memória de O(n²) para O(n)! https://arxiv.org/abs/2205.14135 (Tri Dao, 2022)
PROBLEMA:
- Attention padrão: Q @ K^T @ V = O(n²) memória
- Para n=8192 (contexto longo): 8192² = 67M elementos = 256MB por cabeça!
- 32 cabeças = 8GB só para attention scores
SOLUÇÃO:
- Processar em TILES (blocos)
- Nunca materializar matriz n×n completa
- Online softmax: atualiza estatísticas incrementalmente
RESULTADO: 2-4x mais rápido, O(n) memória!
Types
Configuração Flash Attention
pub type FlashConfig {
FlashConfig(
block_q: Int,
block_kv: Int,
scale: Float,
causal: Bool,
)
}
Constructors
-
FlashConfig( block_q: Int, block_kv: Int, scale: Float, causal: Bool, )Arguments
- block_q
-
Tamanho do tile para Q (linhas)
- block_kv
-
Tamanho do tile para KV (colunas)
- scale
-
Scaling factor (1/sqrt(d))
- causal
-
Usar causal mask
Resultado de Flash Attention
pub type FlashResult {
FlashResult(
output: tensor.Tensor,
memory_bytes: Int,
memory_saved_percent: Float,
)
}
Constructors
-
FlashResult( output: tensor.Tensor, memory_bytes: Int, memory_saved_percent: Float, )Arguments
- output
-
Tensor de saída
- memory_bytes
-
Memória usada (bytes)
- memory_saved_percent
-
Memória economizada vs naive (%)
Estatísticas online para softmax
pub type OnlineStats {
OnlineStats(
max_val: Float,
sum_exp: Float,
output: List(Float),
)
}
Constructors
-
OnlineStats(max_val: Float, sum_exp: Float, output: List(Float))Arguments
- max_val
-
Máximo corrente (para estabilidade numérica)
- sum_exp
-
Soma exponenciada corrente
- output
-
Output acumulado
Values
pub fn benchmark_flash_attention() -> Nil
pub fn causal_config(head_dim: Int) -> FlashConfig
Config para causal (autoregressive)
pub fn flash_attention(
q: tensor.Tensor,
k: tensor.Tensor,
v: tensor.Tensor,
config: FlashConfig,
) -> FlashResult
Flash Attention: processa em tiles, nunca materializa n×n
pub fn naive_attention(
q: tensor.Tensor,
k: tensor.Tensor,
v: tensor.Tensor,
scale: Float,
) -> #(tensor.Tensor, Int)
Attention ingênua: O(n²) memória scores = Q @ K^T attn = softmax(scores * scale) out = attn @ V