Edifice.Feedforward.KAT (Edifice v0.2.0)


KAT: KAN-Attention Transformer — attention blocks with KAN replacing FFN.

Combines standard multi-head self-attention with Kolmogorov-Arnold Network (KAN) layers, which replace the usual MLP as the feed-forward sublayer. KAN layers put learnable activation functions on edges, built from basis functions such as B-splines, sines, or Chebyshev polynomials, instead of fixed activations on nodes.
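
As a concrete illustration, the forward pass of a single KAN layer can be sketched in a few lines of Nx. This is a minimal sketch with a sine basis and an assumed coefficient layout, not Edifice's internal implementation:

defmodule KanSketch do
  import Nx.Defn

  # x: {batch, d_in} input; coeffs: {d_in, d_out, grid} learnable
  # coefficients, one vector of `grid` basis weights per edge (i, j).
  defn forward(x, coeffs) do
    grid = Nx.axis_size(coeffs, 2)

    # Sine basis with frequencies 1..grid, evaluated per input
    # feature: shape {batch, d_in, grid}.
    freqs = Nx.iota({grid}) + 1
    basis = Nx.sin(Nx.new_axis(x, -1) * freqs)

    # Each output j sums phi_ij(x_i) over inputs i, where phi_ij is
    # the coefficient-weighted basis expansion: result {batch, d_out}.
    Nx.dot(basis, [1, 2], coeffs, [0, 2])
  end
end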

Architecture

Input [batch, seq, embed_dim]
      |
      v
+--------------------------------------+
|  TransformerBlock (per layer):       |
|    norm -> Multi-Head Attention      |
|    norm -> KAN Layer (replaces FFN)  |
+--------------------------------------+
      | (repeat num_layers)
      v
Final Norm -> Last Timestep
Output [batch, hidden_size]
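
In Axon terms, one block corresponds roughly to the following pre-norm wiring. This is a structural sketch only: attention/2 and kan_layer/2 are hypothetical stand-ins for Edifice's internal layers, and the residual connections are assumed from standard transformer practice, not stated by this page:

def block(input, opts) do
  attended =
    input
    |> Axon.layer_norm()
    |> attention(opts)     # multi-head self-attention (hypothetical helper)
    |> Axon.add(input)     # residual connection (assumed)

  attended
  |> Axon.layer_norm()
  |> kan_layer(opts)       # KAN sublayer in place of the MLP FFN (hypothetical)
  |> Axon.add(attended)    # residual connection (assumed)
end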

Why KAN Instead of FFN?

Aspect            | Standard FFN                 | KAN FFN
------------------|------------------------------|---------------------------
Activation        | Fixed (ReLU/GELU) on nodes   | Learnable on edges
Expressiveness    | Requires width for accuracy  | Learns optimal activation
Interpretability  | Low                          | Higher (visualizable)
Parameters        | O(n^2)                       | O(n^2 * grid_size)
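
To make the last row concrete: with hidden size 256 and the default grid size of 8, a square d x d KAN sublayer carries roughly 8x the parameters of the equivalent dense layer. A back-of-envelope estimate that ignores expansion factors and bias terms:

d = 256
grid = 8
ffn_params = d * d          # one scalar weight per edge   => 65_536
kan_params = d * d * grid   # `grid` coefficients per edge => 524_288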

Usage

alias Edifice.Feedforward.KAT

model = KAT.build(
  embed_dim: 287,
  hidden_size: 256,
  num_heads: 4,
  grid_size: 8,
  basis: :bspline,
  num_layers: 4
)
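
The built model is an ordinary Axon graph, so it can be initialized and run with Axon's standard build/predict flow. A hedged sketch, assuming the default window_size of 60 and the embed_dim from the example above:

{init_fn, predict_fn} = Axon.build(model)

params = init_fn.(Nx.template({1, 60, 287}, :f32), %{})
input = Nx.broadcast(0.0, {1, 60, 287})

predict_fn.(params, input)
#=> tensor of shape {1, 256}, i.e. [batch, hidden_size]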

References

  • Liu et al., "KAN: Kolmogorov-Arnold Networks" (2024)
  • Vaswani et al., "Attention Is All You Need" (2017)

Summary

Types

build_opt()
Options for build/1.

Functions

build(opts \\ [])
Build a KAT (KAN-Attention Transformer) model.

output_size(opts \\ [])
Get the output size of a KAT model.

Get recommended defaults for KAT.

Types

build_opt()

@type build_opt() ::
  {:basis, :bspline | :sine | :chebyshev | :fourier | :rbf}
  | {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:grid_size, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_heads, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a KAT (KAN-Attention Transformer) model.

Options

  • :embed_dim - Input embedding dimension (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_heads - Number of attention heads (default: 4)
  • :grid_size - Number of KAN basis functions per edge (default: 8)
  • :basis - KAN basis function: :bspline, :sine, :chebyshev, :fourier, :rbf (default: :bspline)
  • :num_layers - Number of transformer layers (default: 4)
  • :dropout - Dropout rate (default: 0.1)
  • :window_size - Sequence length (default: 60)

Returns

An Axon model outputting [batch, hidden_size].

output_size(opts \\ [])

@spec output_size(keyword()) :: pos_integer()

Get the output size of a KAT model.
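
Since build/1 returns a model whose output is [batch, hidden_size], output_size/1 presumably mirrors the :hidden_size option. A hedged illustration:

KAT.output_size(hidden_size: 128)
#=> 128

KAT.output_size()
#=> 256  (default hidden_size)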