Mega: Moving Average Equipped Gated Attention.
Implements the Mega architecture from "Mega: Moving Average Equipped Gated Attention" (Ma et al., ICLR 2023). Mega combines exponential moving averages (EMA) for local context with single-head gated attention for global context, achieving strong performance at lower cost than multi-head attention (the paper's chunked variant, Mega-chunk, further reduces attention to sub-quadratic complexity).
Key Innovation: EMA + Gated Attention
Each Mega block has three sub-layers:
- EMA sub-layer: Multi-dimensional exponential moving average captures local temporal patterns with learnable decay rates per dimension
- Gated attention: Single-head attention with sigmoid gating provides selective global context aggregation
- FFN: Standard feed-forward network for feature transformation
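The EMA sub-layer's recurrence can be illustrated with a minimal plain-Python sketch. This only shows the per-dimension decay; the actual sub-layer also expands each input to ema_dim dimensions, applies a learned damping factor, and projects back to the model dimension, all of which are omitted here:

```python
import math

def ema(xs, a):
    """Per-dimension EMA with learnable decay logits `a`:
    alpha = sigmoid(a_d); h_t = alpha * h_{t-1} + (1 - alpha) * x_t.
    xs: list of T input vectors, a: one logit per dimension.
    Returns the final state h_T (illustrative sketch only)."""
    alphas = [1.0 / (1.0 + math.exp(-ai)) for ai in a]
    h = [0.0] * len(a)
    for x in xs:
        h = [al * hi + (1.0 - al) * xi
             for al, hi, xi in zip(alphas, h, x)]
    return h

# Two dimensions: dim 0 has small alpha (state tracks the newest input),
# dim 1 has large alpha (state changes slowly, keeping long-range memory).
h = ema([[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]], a=[-4.0, 4.0])
```

Because each dimension gets its own decay rate, the sub-layer can mix fast-reacting and slow-reacting features in a single pass.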
Mega Block:
input -> LayerNorm -> EMA -> residual
-> LayerNorm -> GatedAttn -> residual
-> LayerNorm -> FFN -> residual
Architecture
Input [batch, seq_len, embed_dim]
|
v
+-----------------------+
| Input Projection |
+-----------------------+
|
v
+-----------------------+
| Mega Block x N |
| EMA Sub-Layer |
| alpha = sigmoid(a) |
| h_t = alpha*h_{t-1}|
| + (1-alpha)*x_t|
| Gated Attention |
| Q, K, V projections|
| gate * attn_output |
| FFN |
+-----------------------+
|
v
[batch, hidden_size] (last timestep)
Complexity
| Operation | Standard Attention | Mega |
|---|---|---|
| Local | O(L^2) | O(L * D_ema) via EMA |
| Global | O(L^2 * H) | O(L^2) single-head |
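The gate * attn_output step shown in the block diagram can be sketched for a single query position in plain Python. The projections that would produce the query, keys, values, and gate logit are omitted, and all names here are illustrative, not the module's API:

```python
import math

def softmax(v):
    m = max(v)
    e = [math.exp(x - m) for x in v]
    s = sum(e)
    return [x / s for x in e]

def gated_attn(q, ks, vs, gate_logit):
    """Single-head scaled dot-product attention for one query,
    followed by a sigmoid gate on the attended output (sketch)."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)
              for k in ks]
    w = softmax(scores)                      # attention weights over positions
    attn = [sum(wi * v[j] for wi, v in zip(w, vs))
            for j in range(d)]               # weighted sum of values
    g = 1.0 / (1.0 + math.exp(-gate_logit))  # sigmoid gate in (0, 1)
    return [g * a for a in attn]
```

The gate lets the block suppress the global attention path entirely (gate near 0) and rely on the EMA sub-layer's local context instead.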
Usage
model = Mega.build(
embed_dim: 287,
hidden_size: 256,
ema_dim: 16,
num_layers: 4
)
Reference
- Paper: "Mega: Moving Average Equipped Gated Attention"
- arXiv: https://arxiv.org/abs/2209.10655
Summary
Functions
Build a Mega model for sequence processing.
Build a single Mega block with EMA + gated attention + FFN.
Get the output size of a Mega model.
Get recommended defaults.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:ema_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:laplace_attention, boolean()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a Mega model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :ema_dim - Dimensionality of EMA expansion (default: 16)
- :num_layers - Number of Mega blocks (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length (default: 60)
- :laplace_attention - Use Laplace attention instead of softmax (default: false)
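When :laplace_attention is enabled, the softmax over attention scores is replaced by the element-wise Laplace function from the Mega paper. A sketch, assuming the paper's suggested constants mu = sqrt(1/2) and sigma = 1/2:

```python
import math

def laplace(x, mu=math.sqrt(0.5), sigma=0.5):
    """Laplace attention function: 0.5 * (1 + erf((x - mu) / (sigma * sqrt(2)))).
    Squashes each score independently into (0, 1) rather than
    normalizing the scores against each other as softmax does."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))
```

Since each score is mapped independently, the resulting attention weights need not sum to 1 across positions.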
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
Build a single Mega block with EMA + gated attention + FFN.
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a Mega model.
@spec recommended_defaults() :: keyword()
Get recommended defaults.