Edifice.Attention.Mega (Edifice v0.2.0)


Mega: Moving Average Equipped Gated Attention.

Implements the Mega architecture from "Mega: Moving Average Equipped Gated Attention" (Ma et al., ICLR 2023). Mega combines exponential moving averages (EMA) for local context with single-head gated attention for global context, achieving strong performance at lower cost than standard multi-head attention.

Key Innovation: EMA + Gated Attention

Each Mega block has three sub-layers:

  1. EMA sub-layer: Multi-dimensional exponential moving average captures local temporal patterns with learnable decay rates per dimension
  2. Gated attention: Single-head attention with sigmoid gating provides selective global context aggregation
  3. FFN: Standard feed-forward network for feature transformation
Mega Block:
  input -> LayerNorm -> EMA -> residual
        -> LayerNorm -> GatedAttn -> residual
        -> LayerNorm -> FFN -> residual
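
A minimal sketch of this residual wiring in Axon might look like the following; ema/2 and gated_attention/2 are hypothetical stand-ins for the first two sub-layers, not functions exported by this module:

def mega_block(input, opts) do
  hidden = Keyword.fetch!(opts, :hidden_size)

  # EMA sub-layer: pre-norm, then residual
  ema_out =
    input
    |> Axon.layer_norm()
    |> ema(opts)
    |> Axon.add(input)

  # Gated attention sub-layer: pre-norm, then residual
  attn_out =
    ema_out
    |> Axon.layer_norm()
    |> gated_attention(opts)
    |> Axon.add(ema_out)

  # FFN sub-layer: pre-norm, then residual
  attn_out
  |> Axon.layer_norm()
  |> Axon.dense(hidden * 2, activation: :silu)
  |> Axon.dense(hidden)
  |> Axon.add(attn_out)
end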

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+-----------------------+
| Input Projection      |
+-----------------------+
      |
      v
+-----------------------+
| Mega Block x N        |
|  EMA Sub-Layer        |
|    alpha = sigmoid(a) |
|    h_t = alpha*h_{t-1}|
|        + (1-alpha)*x_t|
|  Gated Attention      |
|    Q, K, V projections|
|    gate * attn_output |
|  FFN                  |
+-----------------------+
      |
      v
[batch, hidden_size]    (last timestep)
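
As a toy illustration of the EMA recurrence above, here is an eager-mode Nx sketch for a single unbatched sequence; x (shape {seq_len, ema_dim}) and the pre-sigmoid decay a (shape {ema_dim}) are assumed inputs, not this module's internals:

# h_t = alpha * h_{t-1} + (1 - alpha) * x_t, with alpha = sigmoid(a)
alpha = Nx.sigmoid(a)
h0 = Nx.broadcast(0.0, {Nx.axis_size(x, 1)})

h_last =
  Enum.reduce(0..(Nx.axis_size(x, 0) - 1), h0, fn t, h ->
    Nx.add(Nx.multiply(alpha, h), Nx.multiply(Nx.subtract(1.0, alpha), x[t]))
  end)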

Complexity

Operation | Standard Attention | Mega
----------|--------------------|----------------------
Local     | O(L^2)             | O(L * D_ema) via EMA
Global    | O(L^2 * H)         | O(L^2), single head

Usage

alias Edifice.Attention.Mega

model =
  Mega.build(
    embed_dim: 287,
    hidden_size: 256,
    ema_dim: 16,
    num_layers: 4
  )
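
Once built, the model initializes and runs like any other Axon graph. The shapes below are illustrative (batch of 8, a 60-step window, 287-dim embeddings) and assume a single-input model:

{init_fn, predict_fn} = Axon.build(model)

params = init_fn.(Nx.template({8, 60, 287}, :f32), %{})
{input, _key} = Nx.Random.uniform(Nx.Random.key(0), shape: {8, 60, 287})

output = predict_fn.(params, input)
# => shape {8, 256}: [batch, hidden_size] taken from the last position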

Reference

Ma, X., et al. "Mega: Moving Average Equipped Gated Attention." ICLR 2023. https://arxiv.org/abs/2209.10655

Summary

Types

Options for build/1.

Functions

Build a Mega model for sequence processing.

Build a single Mega block with EMA + gated attention + FFN.

Get the output size of a Mega model.

Get recommended defaults.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:ema_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:laplace_attention, boolean()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Mega model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :ema_dim - Dimensionality of EMA expansion (default: 16)
  • :num_layers - Number of Mega blocks (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length (default: 60)
  • :laplace_attention - Use Laplace attention instead of softmax (default: false)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.
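
For example, a deeper model with Laplace attention enabled (all values illustrative, using only the documented options):

model =
  Edifice.Attention.Mega.build(
    embed_dim: 287,
    hidden_size: 512,
    ema_dim: 32,
    num_layers: 6,
    dropout: 0.2,
    laplace_attention: true
  )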

build_mega_block(input, opts)

@spec build_mega_block(Axon.t(), keyword()) :: Axon.t()

Build a single Mega block with EMA + gated attention + FFN.
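
Blocks compose like ordinary Axon layers, so a custom stack can be wired by hand. The option names passed to build_mega_block/2 below are assumed to mirror build/1; treat this as a sketch:

input = Axon.input("sequence", shape: {nil, 60, 287})

model =
  input
  |> Axon.dense(256)
  |> then(fn projected ->
    Enum.reduce(1..4, projected, fn _layer, acc ->
      Edifice.Attention.Mega.build_mega_block(acc, hidden_size: 256, ema_dim: 16)
    end)
  end)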

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a Mega model.
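
Since build/1 outputs [batch, hidden_size], output_size/1 presumably reports the configured hidden size; a hedged example:

Edifice.Attention.Mega.output_size(hidden_size: 256)
# => 256 (assumed: mirrors the :hidden_size option, default 256)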