Edifice.Feedforward.KAN (Edifice v0.2.0)


KAN: Kolmogorov-Arnold Networks with learnable activation functions.

Implements KAN from "KAN: Kolmogorov-Arnold Networks" (Liu et al., 2024). Based on the Kolmogorov-Arnold representation theorem: any continuous multivariate function can be represented as a finite composition of continuous univariate functions and addition.

Key Innovation: Learnable Edge Activations

Unlike MLPs with fixed activations on nodes, KAN has learnable activations on edges:

MLP:  y = W2 * sigma(W1 * x)           # Fixed sigma (ReLU, etc.)
KAN:  y = Sum_j Phi_j(x_j)             # Learnable Phi_j per edge

Each edge activation is parameterized as:

Phi(x) = w_base * SiLU(x) + w_spline * spline(x)
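As a minimal Python sketch of one edge activation (the module itself is Elixir/Axon; the function names and the sine stand-in for the spline term are illustrative):

```python
import math

def silu(x):
    """SiLU (swish): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def edge_activation(x, w_base, w_spline, spline):
    """One KAN edge: Phi(x) = w_base * SiLU(x) + w_spline * spline(x)."""
    return w_base * silu(x) + w_spline * spline(x)

# Illustrative spline stand-in: a single sine basis term.
phi = edge_activation(1.0, w_base=0.5, w_spline=0.5, spline=math.sin)
print(phi)
```

The learnable part lives in `w_base`, `w_spline`, and the spline coefficients; SiLU provides a smooth base path for gradient flow.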

Basis Function Options

This implementation supports multiple basis functions:

| Basis | Formula | Params | Speed |
|---|---|---|---|
| `:bspline` (default) | Sum c_k * B_k(x) (cubic B-spline) | O(o*i*g) | Medium |
| `:sine` | Sum A * sin(omega*x + phi) | O(o*i*g) | Fast |
| `:chebyshev` | Sum c_n * T_n(x) | O(o*i*g) | Fast |
| `:fourier` | Sum (a*cos(omega*x) + b*sin(omega*x)) | O(2*o*i*g) | Medium |
| `:rbf` | Sum w * exp(-\|\|x - mu\|\|^2 / (2*sigma^2)) | O(o*i*g) | Medium |

Architecture

Input [batch, seq_len, embed_dim]
      |
      v
+--------------------------------------+
|              KAN Block               |
|  LayerNorm -> KAN Layer -> Residual  |
|  LayerNorm -> KAN Layer -> Residual  |
+--------------------------------------+
      | (repeat for num_layers)
      v
Output [batch, hidden_size]
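The data flow above can be sketched in NumPy (a hypothetical stand-in that treats each KAN layer as a black-box function; the actual module builds this graph in Axon, and the second sublayer may be wider):

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize over the last (feature) dimension."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def kan_block(x, layer):
    """Pre-norm residual block: two (LayerNorm -> KAN layer -> residual) sublayers."""
    x = x + layer(layer_norm(x))
    x = x + layer(layer_norm(x))
    return x

def kan_backbone(x, layers):
    """x: [batch, seq_len, hidden]; returns the last hidden state [batch, hidden]."""
    for layer in layers:
        x = kan_block(x, layer)
    return x[:, -1, :]

x = np.random.default_rng(0).normal(size=(2, 60, 16))
identity = lambda h: h  # stand-in for a real KAN layer
out = kan_backbone(x, [identity] * 4)
print(out.shape)  # (2, 16)
```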

Usage

# Build KAN backbone
model = KAN.build(
  embed_dim: 287,
  hidden_size: 256,
  num_layers: 4,
  grid_size: 8,
  basis: :sine
)

Comparison with MLP

| Aspect | MLP | KAN |
|---|---|---|
| Activation | Fixed on nodes | Learnable on edges |
| Interpretability | Low | High (visualizable) |
| Parameters | O(n^2) | O(n^2 * g), g = grid size |
| Best for | General tasks | Symbolic/scientific tasks |
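A back-of-envelope comparison of the two parameter counts, matching the O(n^2) vs O(n^2 * g) rows above (illustrative only; the exact count depends on the basis, biases, and norm layers, so use `param_count/1` for real numbers):

```python
def mlp_params(n_in, n_out):
    """Dense layer: one weight per edge plus a bias per output."""
    return n_in * n_out + n_out

def kan_params(n_in, n_out, grid_size):
    """KAN layer: grid_size basis coefficients per edge, plus a base weight per edge."""
    return n_in * n_out * grid_size + n_in * n_out

print(mlp_params(256, 256))       # 65792
print(kan_params(256, 256, 8))    # 589824
```

With the default grid size of 8, a KAN layer of the same width carries roughly 9x the parameters of the equivalent dense layer.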

References

  • Liu et al., "KAN: Kolmogorov-Arnold Networks" (2024). arXiv:2404.19756

Summary

Types

Options for build/1.

Functions

Build a KAN model for sequence processing.

Build a single KAN block.

Build a KAN layer with learnable edge activations.

Compute Chebyshev polynomial basis functions.

Default basis function type

Default dropout rate

Default grid size (number of basis functions)

Default hidden dimension

Default number of layers

Epsilon for numerical stability

Get the output size of a KAN model.

Calculate approximate parameter count for a KAN model.

Compute RBF (Radial Basis Function) basis.

Get recommended defaults for sequence processing.

Compute sine basis functions.

Types

build_opt()

@type build_opt() ::
  {:dropout, float()}
  | {:embed_dim, pos_integer()}
  | {:hidden_size, pos_integer()}
  | {:num_layers, pos_integer()}
  | {:seq_len, pos_integer()}
  | {:window_size, pos_integer()}

Options for build/1.

Functions

build(opts \\ [])

@spec build([build_opt()]) :: Axon.t()

Build a KAN model for sequence processing.

Options

  • :embed_dim - Size of input embedding per frame (required)
  • :hidden_size - Internal hidden dimension (default: 256)
  • :num_layers - Number of KAN blocks (default: 4)
  • :grid_size - Number of basis functions per edge (default: 8)
  • :basis - Basis function type: :bspline, :sine, :chebyshev, :fourier, :rbf (default: :bspline)
  • :dropout - Dropout rate (default: 0.0)
  • :window_size - Expected sequence length (default: 60)
  • :base_weight - Weight for base SiLU activation (default: 0.5)

Returns

An Axon model that processes sequences and outputs the last hidden state.

build_kan_block(input, opts \\ [])

@spec build_kan_block(
  Axon.t(),
  keyword()
) :: Axon.t()

Build a single KAN block.

KAN block structure:

  1. LayerNorm -> KAN Layer -> Residual
  2. LayerNorm -> KAN Layer (wider) -> Residual

build_kan_layer(input, out_size, opts \\ [])

@spec build_kan_layer(Axon.t(), pos_integer(), keyword()) :: Axon.t()

Build a KAN layer with learnable edge activations.

KAN layer computes:

y_i = Sum_j Phi_ij(x_j)

Where each Phi_ij is approximated as:

Phi(x) = w_base * SiLU(x) + w_spline * Sum sin(omega*x)

This implementation uses a combination of:

  1. Base activation: SiLU(x) for gradient flow
  2. Learnable activation: Multi-frequency sine basis projected through dense layers
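The two paths above can be sketched with NumPy (hypothetical shapes and names; the real layer is built with Axon dense layers and learned frequencies):

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def kan_layer(x, w_base, freqs, w_spline):
    """SiLU base path plus a multi-frequency sine basis projected densely.
      x:        [batch, in_dim]
      w_base:   [in_dim, out_dim]
      freqs:    [grid]                 sine frequencies
      w_spline: [in_dim * grid, out_dim]
    """
    base = silu(x) @ w_base                                # [batch, out_dim]
    # sin(omega * x) for every (feature, frequency) pair
    sines = np.sin(x[:, :, None] * freqs[None, None, :])   # [batch, in_dim, grid]
    spline = sines.reshape(x.shape[0], -1) @ w_spline      # [batch, out_dim]
    return base + spline

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4))
y = kan_layer(x, rng.normal(size=(4, 3)),
              np.array([1.0, 2.0, 4.0]),
              rng.normal(size=(12, 3)))
print(y.shape)  # (2, 3)
```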

chebyshev_basis(x, arg2)

Compute Chebyshev polynomial basis functions.

ChebyKAN: y = Sum c_n * T_n(x), where T_0(x) = 1, T_1(x) = x, T_{n+1}(x) = 2x*T_n(x) - T_{n-1}(x)
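The recurrence is cheap to evaluate; a plain-Python sketch (illustrative, scalar x, degree >= 1):

```python
def chebyshev_basis(x, degree):
    """T_0 .. T_degree at x via T_{n+1} = 2x*T_n - T_{n-1}."""
    ts = [1.0, x]
    for _ in range(degree - 1):
        ts.append(2 * x * ts[-1] - ts[-2])
    return ts[: degree + 1]

print(chebyshev_basis(0.5, 3))  # [1.0, 0.5, -0.5, -1.0]
```

Inputs are typically squashed (e.g. with tanh) into [-1, 1] first, where the Chebyshev polynomials are well behaved.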

default_basis()

@spec default_basis() :: atom()

Default basis function type

default_dropout()

@spec default_dropout() :: float()

Default dropout rate

default_grid_size()

@spec default_grid_size() :: pos_integer()

Default grid size (number of basis functions)

default_hidden_size()

@spec default_hidden_size() :: pos_integer()

Default hidden dimension

default_num_layers()

@spec default_num_layers() :: pos_integer()

Default number of layers

eps()

@spec eps() :: float()

Epsilon for numerical stability

output_size(opts \\ [])

@spec output_size(keyword()) :: non_neg_integer()

Get the output size of a KAN model.

param_count(opts)

@spec param_count(keyword()) :: non_neg_integer()

Calculate approximate parameter count for a KAN model.

rbf_basis(x, centers, sigma)

Compute RBF (Radial Basis Function) basis.

y = Sum w * exp(-||x - mu||^2 / (2*sigma^2))
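A NumPy sketch of the Gaussian bumps themselves, before the learned weighting (shapes and names are illustrative):

```python
import numpy as np

def rbf_basis(x, centers, sigma):
    """exp(-(x - mu)^2 / (2*sigma^2)) for each center mu."""
    return np.exp(-((x[..., None] - centers) ** 2) / (2 * sigma**2))

phis = rbf_basis(np.array([0.0]), np.array([-1.0, 0.0, 1.0]), sigma=1.0)
print(phis)  # approx [[0.6065, 1.0, 0.6065]]
```

Each input activates the centers nearest to it, which is what makes the RBF basis local: moving one center's weight only reshapes the function near that center.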

sine_basis(x, frequencies, phases)

Compute sine basis functions.

SineKAN: y = Sum A * sin(omega*x + phi)
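The raw basis before the learned amplitudes A, sketched in NumPy (illustrative names and shapes):

```python
import numpy as np

def sine_basis(x, frequencies, phases):
    """sin(omega*x + phi) for each (frequency, phase) pair."""
    return np.sin(x[..., None] * frequencies + phases)

out = sine_basis(np.array([0.0, np.pi / 2]),
                 frequencies=np.array([1.0, 2.0]),
                 phases=np.zeros(2))
print(out.shape)  # (2, 2)
```

Unlike splines or RBFs, each sine term is global: one frequency affects the function everywhere, which is why this basis is fast (no grid lookup) but less local.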