Edifice.Blocks.TransformerBlock (Edifice v0.2.0)

Composable transformer block with configurable attention/mixing function.

Implements the standard pre-norm transformer block pattern:

norm -> attention_fn -> residual -> norm -> FFN -> residual

The caller provides the attention/mixing function as a callback, making this block reusable across GQA, FNet, Performer, Nystromformer, LinearTransformer, and any future attention variant.

Architecture

Input
  |
  +---> LayerNorm/RMSNorm -> attention_fn(x) ----+
  |                                              |
  +<----- Residual + Dropout <-------------------+
  |
  +---> LayerNorm/RMSNorm -> FFN(x) -------------+
  |                                              |
  +<----- Residual + Dropout <-------------------+
  |
Output
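
In Axon terms, the wiring of a single pre-norm block looks roughly like the sketch below (a minimal illustration, not this module's actual implementation; the names, dropout rate, and 4x FFN expansion are placeholders, and attention_fn is the caller-supplied callback):

# One pre-norm block, sketched with plain Axon layers.
# Attention sublayer: norm -> mix -> dropout, then residual add.
attn =
  input
  |> Axon.layer_norm(name: "blk_attn_norm")
  |> attention_fn.("blk_attn")
  |> Axon.dropout(rate: 0.1, name: "blk_attn_drop")

x = Axon.add(input, attn, name: "blk_attn_residual")

# FFN sublayer: norm -> expand -> activation -> project -> dropout, then residual add.
ffn =
  x
  |> Axon.layer_norm(name: "blk_ffn_norm")
  |> Axon.dense(4 * 256, name: "blk_ffn_up")
  |> Axon.activation(:gelu, name: "blk_ffn_act")
  |> Axon.dense(256, name: "blk_ffn_down")
  |> Axon.dropout(rate: 0.1, name: "blk_ffn_drop")

output = Axon.add(x, ffn, name: "blk_ffn_residual")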

Usage

alias Edifice.Blocks.TransformerBlock

# Single block with custom attention
block = TransformerBlock.layer(input,
  attention_fn: fn x, name -> build_my_attention(x, name) end,
  hidden_size: 256,
  name: "block_1"
)

# Stack N blocks
output = TransformerBlock.stack(input, 4,
  attention_fn: fn x, name -> build_my_attention(x, name) end,
  hidden_size: 256,
  name: "transformer"
)
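
The attention_fn callback receives the current Axon node and a name prefix, and must return a new Axon node of the same shape. A toy stand-in for the build_my_attention/2 used above, assuming it lives in the caller's module (a single dense layer for illustration only, not one of the library's attention variants):

# Hypothetical mixing function: one dense layer in place of real attention
def build_my_attention(x, name) do
  Axon.dense(x, 256, name: "#{name}_mix")
end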

Design

Follows the callback pattern established by Edifice.SSM.Common.build_block/3, where the caller provides the core mixing/attention function and this module handles the surrounding structure (normalization, residuals, FFN, dropout).

Summary

Functions

layer(input, opts) - Build a single transformer block.

stack(input, num_layers, opts) - Stack N transformer blocks with auto-naming.

Functions

layer(input, opts)

@spec layer(Axon.t(), keyword()) :: Axon.t()

Build a single transformer block.

Options

  • :attention_fn - Function (input, name) -> Axon.t() that builds the attention/mixing sublayer (required)
  • :hidden_size - Hidden dimension (required)
  • :ffn_type - FFN variant: :standard or :gated (default: :standard)
  • :ffn_expansion - FFN expansion factor (default: 4)
  • :custom_ffn - Custom FFN callback (input, name) -> Axon.t() that replaces the standard FFN sublayer. When provided, :ffn_type and :ffn_expansion are ignored. Used by KAT and other architectures that need non-standard feed-forward networks.
  • :norm - Normalization type: :layer_norm or :rms_norm (default: :layer_norm)
  • :norm_position - Where to normalize: :pre or :post (default: :pre)
  • :dropout - Dropout rate (default: 0.0)
  • :name - Block name prefix (default: "transformer_block")
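
For example, a block that swaps in a gated FFN and RMSNorm (a sketch; the inline attention callback is a placeholder, not a real attention variant):

block =
  TransformerBlock.layer(input,
    attention_fn: fn x, name -> Axon.dense(x, 256, name: "#{name}_mix") end,
    hidden_size: 256,
    ffn_type: :gated,
    norm: :rms_norm,
    dropout: 0.1,
    name: "block_1"
  )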

stack(input, num_layers, opts)

@spec stack(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Stack N transformer blocks with auto-naming.

Options

Accepts the same options as layer/2. The positional arguments are:

  • input - the input Axon node
  • num_layers - the number of layers to stack

Returns

The output of the final block (same shape as input).
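
A minimal end-to-end sketch (the shapes, sizes, and dense stand-in for attention are illustrative):

input = Axon.input("tokens", shape: {nil, 128, 256})

model =
  Edifice.Blocks.TransformerBlock.stack(input, 4,
    attention_fn: fn x, name -> Axon.dense(x, 256, name: "#{name}_mix") end,
    hidden_size: 256,
    name: "encoder"
  )

# The final node has the same shape as the input: {batch, 128, 256}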