KAT: KAN-Attention Transformer — attention blocks with KAN replacing FFN.
Combines standard multi-head self-attention with Kolmogorov-Arnold Network (KAN) layers in place of the usual MLP feed-forward sublayer. KAN layers put learnable activation functions on edges, built from basis functions such as B-splines, sines, or Chebyshev polynomials, instead of fixed activations on nodes.
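For intuition, here is a minimal Nx sketch of a single KAN edge with a sine basis, phi(x) = sum_k c_k * sin(k * x). The module name and coefficient layout are illustrative assumptions, not this library's internals:

```elixir
# Sketch of one KAN edge activation with a sine basis:
# phi(x) = sum_k c_k * sin(k * x). Illustrative only.
defmodule KanEdgeSketch do
  import Nx.Defn

  # x:      {batch} float inputs, one scalar per example
  # coeffs: {grid_size} learnable coefficients c_1..c_K
  defn activate(x, coeffs) do
    grid_size = Nx.axis_size(coeffs, 0)
    freqs = Nx.iota({grid_size}) + 1                            # k = 1..K
    basis = Nx.sin(Nx.new_axis(x, 1) * Nx.new_axis(freqs, 0))  # {batch, K}
    Nx.dot(basis, coeffs)                                       # {batch}
  end
end

# KanEdgeSketch.activate(Nx.tensor([0.5, -1.0]), Nx.tensor([0.3, 0.1, 0.2]))
```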
Architecture
```
Input [batch, seq, embed_dim]
        |
        v
+-------------------------------------+
| TransformerBlock (per layer):       |
|   norm -> MultiHead Attention       |
|   norm -> KAN Layer (replaces FFN)  |
+-------------------------------------+
        | (repeat num_layers)
        v
Final Norm -> Last Timestep
Output [batch, hidden_size]
```

Why KAN Instead of FFN?
| Aspect | Standard FFN | KAN FFN |
|---|---|---|
| Activation | Fixed (ReLU/GELU) on nodes | Learnable on edges |
| Expressiveness | Needs extra width for accuracy | Learns activation shape per edge |
| Interpretability | Low | Higher (visualizable) |
| Parameters | O(n^2) | O(n^2 * grid_size) |
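Back-of-the-envelope example: with hidden_size = 256 and grid_size = 8, a square 256 -> 256 KAN sublayer carries roughly 256 × 256 × 8 ≈ 524k basis coefficients, where a plain 256 -> 256 linear map has about 65k weights; the grid size multiplies the per-edge parameter count (biases and the usual 4× FFN expansion are ignored here).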
Usage
```elixir
model = KAT.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  grid_size: 8,
  basis: :bspline,
  num_layers: 4
)
```

References
- Liu et al., "KAN: Kolmogorov-Arnold Networks" (2024)
- Vaswani et al., "Attention Is All You Need" (2017)
Summary
Functions

build(opts) - Build a KAT (KAN-Attention Transformer) model.
output_size(opts) - Get the output size of a KAT model.
recommended_defaults() - Get recommended defaults for KAT.
Types
```elixir
@type build_opt() ::
        {:basis, :bspline | :sine | :chebyshev | :fourier | :rbf}
        | {:dropout, float()}
        | {:embed_dim, pos_integer()}
        | {:grid_size, pos_integer()}
        | {:hidden_size, pos_integer()}
        | {:num_heads, pos_integer()}
        | {:num_layers, pos_integer()}
        | {:window_size, pos_integer()}
```
Options for build/1.
Functions
Build a KAT (KAN-Attention Transformer) model.
Options
- :embed_dim - Input embedding dimension (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_heads - Number of attention heads (default: 4)
- :grid_size - Number of KAN basis functions per edge (default: 8; see the sketch after this list)
- :basis - KAN basis function: :bspline, :sine, :chebyshev, :fourier, or :rbf (default: :bspline)
- :num_layers - Number of transformer layers (default: 4)
- :dropout - Dropout rate (default: 0.1)
- :window_size - Sequence length (default: 60)
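To make :grid_size and :basis concrete, here is a minimal Nx sketch of a full KAN sublayer using a sine basis. The coefficient layout, module name, and basis choice are assumptions for illustration, not this library's internals:

```elixir
# Illustrative KAN sublayer: every input->output edge gets its own
# learnable activation built from grid_size sine basis functions.
# The {in, grid, out} coefficient layout is an assumption of this sketch.
defmodule KanLayerSketch do
  import Nx.Defn

  # x:      {batch, in}
  # coeffs: {in, grid_size, out}
  # returns {batch, out} where y[b, j] = sum_i sum_k C[i, k, j] * sin(k * x[b, i])
  defn forward(x, coeffs) do
    grid_size = Nx.axis_size(coeffs, 1)
    freqs = Nx.iota({grid_size}) + 1
    # {batch, in, grid}: one basis expansion per input feature
    basis = Nx.sin(Nx.new_axis(x, 2) * Nx.reshape(freqs, {1, 1, grid_size}))
    # contract over {in, grid} against coeffs {in, grid, out} -> {batch, out}
    Nx.dot(basis, [1, 2], coeffs, [0, 1])
  end
end
```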
Returns
An Axon model outputting [batch, hidden_size].
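A hedged end-to-end sketch, assuming build/1 returns a single-input Axon model as documented; everything past the KAT.build/1 call is the standard Axon workflow:

```elixir
model = KAT.build(embed_dim: 287, hidden_size: 256)

{init_fn, predict_fn} = Axon.build(model)
params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})

# batch of 8 sequences, window_size 60, embed_dim 287
input = Nx.broadcast(0.0, {8, 60, 287})
output = predict_fn.(params, input)
# output shape: {8, 256} -> [batch, hidden_size]
```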
@spec output_size(keyword()) :: pos_integer()
Get the output size of a KAT model.
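Presumably this mirrors the :hidden_size option, since the model outputs [batch, hidden_size]; treat the exact behavior as an assumption:

```elixir
KAT.output_size(hidden_size: 512)
#=> 512 (assumed: follows :hidden_size, default 256)
```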
@spec recommended_defaults() :: keyword()
Get recommended defaults for KAT.
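Since recommended_defaults/0 returns a keyword list and build/1 accepts one, a natural pattern (assumed, not documented here) is merging overrides before building:

```elixir
KAT.recommended_defaults()
|> Keyword.merge(embed_dim: 287, basis: :chebyshev)
|> KAT.build()
```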