Gated Attention: learned gating over attention output.
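Applies a learnable sigmoid gate to the attention output:
output = sigmoid(g) * Attention(Q, K, V)
where g is a learned gate vector (one scalar per hidden dimension) and * is element-wise multiplication, broadcast over batch and sequence positions.
This allows the model to selectively dampen or pass through attention outputs
per feature dimension. Because sigmoid(g) lies in (0, 1), each dimension is scaled by a factor below 1; a dimension is emphasized only relative to others that are suppressed more.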
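The gating step on its own can be sketched in a few lines. This is a minimal sketch on plain Elixir lists rather than the Nx tensors the library actually uses. It assumes g is stored as raw logits and the attention output is a list of rows, one per sequence position. `GateSketch` and `apply_gate/2` are hypothetical names, not part of the library.

```elixir
defmodule GateSketch do
  # Hypothetical helper (not the library's API): applies
  # output = sigmoid(g) * attn_out with one gate logit per feature
  # dimension, broadcast over every sequence position.

  # Logistic sigmoid: maps any real logit into (0, 1).
  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # gate_logits: list of length dim (the learned vector g).
  # attn_out: list of rows, one row of length dim per position.
  def apply_gate(gate_logits, attn_out) do
    gates = Enum.map(gate_logits, &sigmoid/1)

    Enum.map(attn_out, fn row ->
      Enum.zip_with(gates, row, fn gate, value -> gate * value end)
    end)
  end
end
```

With logits `[0.0, 10.0, -10.0]`, the three dimensions are scaled by 0.5, about 0.99995 and about 0.00005. The gate can therefore only shrink a dimension, never push it above its ungated value.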
Key Innovation
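Standard attention outputs are weighted sums of value vectors, and not every feature dimension of that sum is equally useful; some can be noisy. The gate learns which dimensions of the attention output are reliable or useful and which should be dampened. This is similar in spirit to the gates in LSTMs and GRUs, applied to attention. Unlike those gates, g here is a learned parameter vector rather than a function of the current input.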
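A short gradient derivation shows how the gate "learns" which dimensions to dampen. It assumes the gate is applied exactly as output = sigmoid(g) * Attention(Q, K, V). The symbols are introduced here for illustration: a_{b,t,d} is the attention output at batch element b, position t and dimension d, y is the gated output, and L is the training loss.

```latex
y_{b,t,d} = \sigma(g_d)\, a_{b,t,d}, \qquad \sigma(x) = \frac{1}{1 + e^{-x}}

\frac{\partial L}{\partial g_d}
  = \sum_{b,t} \frac{\partial L}{\partial y_{b,t,d}}
      \frac{\partial y_{b,t,d}}{\partial g_d}
  = \sigma(g_d)\bigl(1 - \sigma(g_d)\bigr)
      \sum_{b,t} \frac{\partial L}{\partial y_{b,t,d}}\, a_{b,t,d}
```

Because g_d is shared across every batch element and position, its update pools evidence over the whole batch. Suppose scaling dimension d up would increase the loss. Then the sum is positive, gradient descent lowers g_d, and that dimension's gate shrinks toward 0. The factor \sigma(g_d)(1 - \sigma(g_d)) is at most 1/4, so gates that have saturated near 0 or 1 change slowly.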
Architecture
Input [batch, seq_len, embed_dim]
              |
+------------------------------+
|     Gated Attention Block    |
|                              |
|     Q, K, V projections      |
|             |                |
|     Standard attention       |
|             |                |
|     sigmoid(g) * attn_out    |
|             |                |
|     Output projection        |
+------------------------------+
              |
[batch, seq_len, hidden_size]
Usage
model = GatedAttention.build(
embed_dim: 256,
hidden_size: 256,
num_heads: 4,
num_layers: 6
)
Reference
- "Gated Attention Networks" (NeurIPS 2025 Best Paper)
Summary
Functions
build/1: Build a Gated Attention model.
Build the gated attention layer.
output_size/1: Get the output dimension for a model configuration.
recommended_defaults/0: Recommended default configuration.
Types
@type build_opt() :: {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_heads, pos_integer()} | {:num_layers, pos_integer()} | {:dropout, float()} | {:window_size, pos_integer()}
Options for build/1.
Functions
build/1
Build a Gated Attention model.
Options
- :embed_dim - Size of input embedding per timestep (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :num_layers - Number of transformer blocks (default: 6)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Expected sequence length for JIT optimization (default: 60)
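These options can be resolved against their documented defaults before a model is built. The following is a minimal sketch, not the library's own validation. `GatedAttentionOptsSketch` and `resolve/1` are hypothetical names. The divisibility check between `:hidden_size` and `:num_heads` is an assumption borrowed from standard multi-head attention and is not stated above.

```elixir
defmodule GatedAttentionOptsSketch do
  # Hypothetical helper, not part of the library: fills in the documented
  # defaults for build/1 and enforces the one required option.
  @defaults [hidden_size: 256, num_heads: 4, num_layers: 6, dropout: 0.1, window_size: 60]

  def resolve(opts) do
    # Raises KeyError when :embed_dim is missing, since it is required.
    Keyword.fetch!(opts, :embed_dim)

    # Caller-supplied values override the defaults.
    resolved = Keyword.merge(@defaults, opts)

    # Assumption: as in standard multi-head attention, hidden_size must
    # split evenly across the heads.
    hidden = Keyword.fetch!(resolved, :hidden_size)
    heads = Keyword.fetch!(resolved, :num_heads)

    if rem(hidden, heads) != 0 do
      raise ArgumentError, "hidden_size #{hidden} is not divisible by num_heads #{heads}"
    end

    resolved
  end
end
```

For example, `resolve(embed_dim: 128, num_heads: 8)` keeps the remaining defaults (such as `hidden_size: 256`) and overrides only the two keys supplied.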
Returns
An Axon model that outputs [batch, hidden_size] from the last position.
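"From the last position" means the sequence axis is dropped by keeping only the final timestep of each example. The library does this on Axon/Nx tensors. This plain-list sketch shows only the indexing, and `LastPositionSketch` is a hypothetical name.

```elixir
defmodule LastPositionSketch do
  # Hypothetical helper: given a batch shaped [batch][seq_len][hidden_size]
  # as nested lists, keep only the final timestep of each sequence,
  # giving [batch][hidden_size].
  def take_last(batch), do: Enum.map(batch, &List.last/1)
end
```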
Build the gated attention layer.
Projects the input to Q, K, and V, computes standard attention, then applies the learned sigmoid gate to the attention output.
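The layer's steps can be traced end to end in plain Elixir: projections, scaled dot-product attention, the sigmoid gate, and the output projection shown in the architecture diagram. This is a minimal single-head sketch with no batch axis, masking or dropout. It is not the library's implementation. The source does not say whether the gate is applied per head or after the heads are concatenated. `GatedAttentionSketch`, `forward/2` and the parameter keys `wq`, `wk`, `wv`, `g`, `wo` are all hypothetical.

```elixir
defmodule GatedAttentionSketch do
  # Minimal single-head sketch on plain lists (no batch, no mask, no dropout).
  # All weights are passed in explicitly; names are hypothetical.

  # Matrix multiply: a is n x k, b is k x m (both lists of rows).
  def matmul(a, b) do
    cols = transpose(b)
    for row <- a, do: for(col <- cols, do: dot(row, col))
  end

  def transpose(m), do: m |> Enum.zip() |> Enum.map(&Tuple.to_list/1)

  def dot(u, v), do: Enum.zip_with(u, v, &(&1 * &2)) |> Enum.sum()

  def sigmoid(x), do: 1.0 / (1.0 + :math.exp(-x))

  # Numerically stable softmax over one row of scores.
  def softmax(row) do
    peak = Enum.max(row)
    exps = Enum.map(row, &:math.exp(&1 - peak))
    total = Enum.sum(exps)
    Enum.map(exps, &(&1 / total))
  end

  # x: seq_len x embed_dim; wq, wk, wv: embed_dim x hidden;
  # g: one gate logit per hidden dimension; wo: hidden x hidden.
  def forward(x, %{wq: wq, wk: wk, wv: wv, g: g, wo: wo}) do
    q = matmul(x, wq)
    k = matmul(x, wk)
    v = matmul(x, wv)

    d = q |> hd() |> length()
    scale = 1.0 / :math.sqrt(d)

    # Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.
    weights =
      q
      |> matmul(transpose(k))
      |> Enum.map(fn row -> row |> Enum.map(&(&1 * scale)) |> softmax() end)

    attn_out = matmul(weights, v)

    # Learned gate: scale each hidden dimension by sigmoid(g_d).
    gates = Enum.map(g, &sigmoid/1)
    gated = Enum.map(attn_out, fn row -> Enum.zip_with(gates, row, &(&1 * &2)) end)

    # Output projection.
    matmul(gated, wo)
  end
end
```

With identity weights and `g = [0.0, 0.0]`, every gate is 0.5. Each attention row sums to 1, so each output row sums to 0.5. Strongly negative logits drive the whole output toward zero.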
@spec output_size(keyword()) :: non_neg_integer()
Get the output dimension for a model configuration.
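Since build/1 returns `[batch, hidden_size]`, the output dimension plausibly equals the configured `:hidden_size`. The sketch below assumes exactly that, falling back to the documented default of 256. The real `output_size/1` may differ, and `OutputSizeSketch` is a hypothetical name.

```elixir
defmodule OutputSizeSketch do
  # Hypothetical re-implementation, assuming the output dimension is the
  # configured :hidden_size (documented default: 256).
  def output_size(opts) when is_list(opts), do: Keyword.get(opts, :hidden_size, 256)
end
```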
@spec recommended_defaults() :: keyword()
Recommended default configuration.