viva_tensor/flash_attention

Flash Attention - O(n) Memory Attention

REVOLUCIONÁRIO: Reduz memória de O(n²) para O(n)! https://arxiv.org/abs/2205.14135 (Tri Dao, 2022)

PROBLEMA:

SOLUÇÃO:

RESULTADO: 2-4x mais rápido, O(n) memória!

Types

Configuração Flash Attention

pub type FlashConfig {
  FlashConfig(
    block_q: Int,
    block_kv: Int,
    scale: Float,
    causal: Bool,
  )
}

Constructors

  • FlashConfig(
      block_q: Int,
      block_kv: Int,
      scale: Float,
      causal: Bool,
    )

    Arguments

    block_q

    Tamanho do tile para Q (linhas)

    block_kv

    Tamanho do tile para KV (colunas)

    scale

    Scaling factor (1/sqrt(d))

    causal

    Usar causal mask

Resultado de Flash Attention

pub type FlashResult {
  FlashResult(
    output: tensor.Tensor,
    memory_bytes: Int,
    memory_saved_percent: Float,
  )
}

Constructors

  • FlashResult(
      output: tensor.Tensor,
      memory_bytes: Int,
      memory_saved_percent: Float,
    )

    Arguments

    output

    Tensor de saída

    memory_bytes

    Memória usada (bytes)

    memory_saved_percent

    Memória economizada vs naive (%)

Estatísticas online para softmax

pub type OnlineStats {
  OnlineStats(
    max_val: Float,
    sum_exp: Float,
    output: List(Float),
  )
}

Constructors

  • OnlineStats(max_val: Float, sum_exp: Float, output: List(Float))

    Arguments

    max_val

    Máximo corrente (para estabilidade numérica)

    sum_exp

    Soma exponenciada corrente

    output

    Output acumulado

Values

pub fn benchmark_flash_attention() -> Nil
pub fn causal_config(head_dim: Int) -> FlashConfig

Config para causal (autoregressive)

pub fn default_config(head_dim: Int) -> FlashConfig

Configuração padrão

pub fn flash_attention(
  q: tensor.Tensor,
  k: tensor.Tensor,
  v: tensor.Tensor,
  config: FlashConfig,
) -> FlashResult

Flash Attention: processa em tiles, nunca materializa n×n

pub fn main() -> Nil
pub fn naive_attention(
  q: tensor.Tensor,
  k: tensor.Tensor,
  v: tensor.Tensor,
  scale: Float,
) -> #(tensor.Tensor, Int)

Attention ingênua: O(n²) memória scores = Q @ K^T attn = softmax(scores * scale) out = attn @ V

Search Document