Edifice.Attention.GatedAttention (Edifice v0.2.0)

Gated Attention: learned gating over attention output.

Applies a learnable sigmoid gate to attention output:

output = sigmoid(g) * Attention(Q, K, V)

where g is a learned gate vector with one scalar per hidden dimension. This lets the model selectively suppress or amplify the attention output along each feature dimension.
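
For intuition, the gate is a single elementwise multiply. A minimal Nx sketch of the math (the shapes and the zero-valued gate below are illustrative stand-ins, not Edifice internals):

# [batch, seq_len, hidden]; dummy values stand in for real attention output
attn_out = Nx.broadcast(1.0, {2, 60, 256})

# one learned scalar per hidden dimension; zeros here are only for illustration
g = Nx.broadcast(0.0, {256})

# sigmoid(0) = 0.5, so an untrained zero gate passes half of each feature
gated = Nx.multiply(Nx.sigmoid(g), attn_out)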

Key Innovation

Standard attention outputs are weighted sums over value vectors and can be noisy. The gate learns which dimensions of the attention output are informative and which should be dampened. This is similar to the gating in LSTMs and GRUs, but applied to the attention output.

Architecture

Input [batch, seq_len, embed_dim]
      |
+------------------------------+
|  Gated Attention Block       |
|                              |
|  Q, K, V projections         |
|         |                    |
|  Standard attention          |
|         |                    |
|  sigmoid(g) * attn_out       |
|         |                    |
|  Output projection           |
+------------------------------+
      |
[batch, seq_len, hidden_size]

Usage

alias Edifice.Attention.GatedAttention

model = GatedAttention.build(
  embed_dim: 256,
  hidden_size: 256,
  num_heads: 4,
  num_layers: 6
)
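
Running the model follows the standard Axon build/init/predict flow. A sketch, assuming embed_dim: 256 and the default window_size of 60 (template and batch sizes are arbitrary):

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 256}, :f32), %{})

input = Nx.broadcast(0.0, {8, 60, 256})
output = predict_fn.(params, input)
# output shape: {8, 256}, i.e. [batch, hidden_size] from the last position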

Reference

  • "Gated Attention Networks" (NeurIPS 2025 Best Paper)

Summary

Types

Options for build/1.

Functions

Build a Gated Attention model.

Build the gated attention layer.

Get the output dimension for a model configuration.

Recommended default configuration.

Types

build_opt()

@type build_opt() ::
  {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:dropout, float()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a Gated Attention model.

Options

  • :embed_dim - Size of input embedding per timestep (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :num_layers - Number of transformer blocks (default: 6)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Expected sequence length for JIT optimization (default: 60)

Returns

An Axon model that outputs [batch, hidden_size] from the last position.

build_gated_attention(input, opts)

@spec build_gated_attention(
  Axon.t(),
  keyword()
) :: Axon.t()

Build the gated attention layer.

Projects the input to Q, K, and V, computes standard attention, then applies the learned sigmoid gate to the output.
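
As a sketch, a gate like this can be expressed as a custom Axon layer. The parameter name, zero initializer, and layer name below are illustrative, not Edifice's actual implementation:

# stand-in for the attention output node, [batch, seq_len, hidden_size]
attn_out = Axon.input("attn_out", shape: {nil, 60, 256})

# one learnable scalar per hidden dimension
gate = Axon.param("gate", {256}, initializer: :zeros)

gated =
  Axon.layer(
    fn attn_out, g, _opts -> Nx.multiply(Nx.sigmoid(g), attn_out) end,
    [attn_out, gate],
    name: "sigmoid_gate",
    op_name: :sigmoid_gate
  )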

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output dimension for a model configuration.
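
Given the Returns note on build/1 above, this presumably mirrors the configured :hidden_size; a hypothetical call:

GatedAttention.output_size(hidden_size: 512)
#=> 512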