Composable transformer block with a configurable attention/mixing function.

Implements the standard pre-norm transformer block pattern:
`norm -> attention_fn -> residual -> norm -> FFN -> residual`. The caller
provides the attention/mixing function as a callback, making this block
reusable across GQA, FNet, Performer, Nystromformer, LinearTransformer, and
any future attention variant.
## Architecture

```
Input
  |
  +---> LayerNorm/RMSNorm -> attention_fn(x) ---+
  |                                             |
  +<---- Residual + Dropout <-------------------+
  |
  +---> LayerNorm/RMSNorm -> FFN(x) ------------+
  |                                             |
  +<---- Residual + Dropout <-------------------+
  |
Output
```

## Usage
```elixir
# Single block with custom attention
block =
  TransformerBlock.layer(input,
    attention_fn: fn x, name -> build_my_attention(x, name) end,
    hidden_size: 256,
    name: "block_1"
  )

# Stack N blocks
output =
  TransformerBlock.stack(input, 4,
    attention_fn: fn x, name -> build_my_attention(x, name) end,
    hidden_size: 256,
    name: "transformer"
  )
```
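For concreteness, here is a minimal sketch of an `:attention_fn` callback: a
parameter-free, FNet-style Fourier mixing layer (simplified to a 1-D transform
over the hidden axis; the `fourier_mix` name and the pre-built `input` node
are assumptions for illustration):

```elixir
fourier_mix = fn x, name ->
  Axon.nx(
    x,
    # FFT over the last (hidden) axis, keeping only the real part
    fn t -> t |> Nx.fft() |> Nx.real() end,
    name: "#{name}_fourier_mix"
  )
end

block =
  TransformerBlock.layer(input,
    attention_fn: fourier_mix,
    hidden_size: 256,
    name: "fnet_block"
  )
```

Any function with the same `(input, name) -> Axon.t()` shape can be swapped
in, which is what makes the block reusable across attention variants.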
## Design

Follows the callback pattern established by `Edifice.SSM.Common.build_block/3`,
where the caller provides the core mixing/attention function and this module
handles the surrounding structure (normalization, residuals, FFN, and dropout).
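For intuition, a pre-norm block with the default options composes roughly like
the hand-written sketch below (illustrative only, not this module's actual
implementation; `build_my_attention/2`, the `:gelu` activation, and the 4x
expansion are assumptions):

```elixir
hidden_size = 256

# Attention sublayer: pre-norm, mix, dropout, residual
attn_out =
  input
  |> Axon.layer_norm(name: "blk_norm1")
  |> build_my_attention("blk_attention")
  |> Axon.dropout(rate: 0.1)
  |> Axon.add(input)

# FFN sublayer: pre-norm, expand, activate, project, dropout, residual
block_out =
  attn_out
  |> Axon.layer_norm(name: "blk_norm2")
  |> Axon.dense(4 * hidden_size, name: "blk_ffn_up")
  |> Axon.activation(:gelu)
  |> Axon.dense(hidden_size, name: "blk_ffn_down")
  |> Axon.dropout(rate: 0.1)
  |> Axon.add(attn_out)
```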
## Functions

### layer(input, opts)

Build a single transformer block.
#### Options

  * `:attention_fn` - function `(input, name) -> Axon.t()` that builds the
    attention/mixing sublayer (required)
  * `:hidden_size` - hidden dimension (required)
  * `:ffn_type` - FFN variant, `:standard` or `:gated` (default: `:standard`)
  * `:ffn_expansion` - FFN expansion factor (default: `4`)
  * `:custom_ffn` - custom FFN callback `(input, name) -> Axon.t()` that
    replaces the standard FFN sublayer; see the sketch after this list. When
    provided, `:ffn_type` and `:ffn_expansion` are ignored. Used by KAT and
    other architectures that need non-standard feed-forward networks.
  * `:norm` - normalization type, `:layer_norm` or `:rms_norm` (default:
    `:layer_norm`)
  * `:norm_position` - where to normalize, `:pre` or `:post` (default: `:pre`)
  * `:dropout` - dropout rate (default: `0.0`)
  * `:name` - block name prefix (default: `"transformer_block"`)
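As an example of the non-default options, the sketch below wires in
`:rms_norm` and a SwiGLU-style `:custom_ffn` (the callback body, the sizes,
and the `:silu` activation are illustrative assumptions):

```elixir
custom_ffn = fn x, name ->
  # Gated feed-forward: silu(W_gate x) * (W_up x), projected back down
  gate =
    x
    |> Axon.dense(1024, name: "#{name}_gate")
    |> Axon.activation(:silu)

  up = Axon.dense(x, 1024, name: "#{name}_up")

  gate
  |> Axon.multiply(up)
  |> Axon.dense(256, name: "#{name}_down")
end

block =
  TransformerBlock.layer(input,
    attention_fn: fn x, name -> build_my_attention(x, name) end,
    custom_ffn: custom_ffn,
    hidden_size: 256,
    norm: :rms_norm,
    dropout: 0.1,
    name: "custom_block"
  )
```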
### stack(input, num_layers, opts)

    @spec stack(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Stack N transformer blocks with auto-naming.

#### Options

Same as `layer/2`. The first argument is the input Axon node; the second is
the number of blocks to stack.

#### Returns

The output of the final block (same shape as the input).
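Because each block preserves the input shape, a quick sanity check for a stack
is to build the model and run a template through it (a sketch; the `"tokens"`
input name, the shapes, and `build_my_attention/2` are assumptions):

```elixir
input = Axon.input("tokens", shape: {nil, 16, 256})

model =
  TransformerBlock.stack(input, 4,
    attention_fn: fn x, name -> build_my_attention(x, name) end,
    hidden_size: 256,
    name: "transformer"
  )

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 16, 256}, :f32), %{})

predict_fn.(params, Nx.broadcast(0.0, {1, 16, 256})) |> Nx.shape()
#=> {1, 16, 256}
```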