KAN: Kolmogorov-Arnold Networks with learnable activation functions.
Implements KAN from "KAN: Kolmogorov-Arnold Networks" (Liu et al., 2024). The design is based on the Kolmogorov-Arnold representation theorem: any continuous multivariate function on a bounded domain can be written as a finite composition of continuous univariate functions and addition.
Key Innovation: Learnable Edge Activations
Unlike MLPs with fixed activations on nodes, KAN has learnable activations on edges:
    MLP: y = W2 * sigma(W1 * x)    # Fixed sigma (ReLU, etc.)
    KAN: y = Sum_j Phi_j(x_j)      # Learnable Phi_j per edge

Each edge activation is parameterized as:

    Phi(x) = w_base * SiLU(x) + w_spline * spline(x)

Basis Function Options
This implementation supports multiple basis functions:
| Basis | Formula | Params | Speed |
|---|---|---|---|
| :bspline (default) | Sum c * B_k(x) (cubic B-spline) | O(o*i*g) | Medium |
| :sine | Sum A * sin(omega*x + phi) | O(o*i*g) | Fast |
| :chebyshev | Sum c * T_n(x) | O(o*i*g) | Fast |
| :fourier | Sum (a*cos + b*sin) | O(2*o*i*g) | Medium |
| :rbf | Sum w * exp(-\|\|x-mu\|\|^2 / (2*sigma^2)) | O(o*i*g) | Medium |
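The :fourier row carries twice the parameters of the other bases because each frequency contributes both a cosine and a sine term. A minimal sketch in plain Elixir, assuming fixed integer frequencies 1..grid_size (the library may well use learnable frequencies instead); `FourierBasisSketch` and its functions are hypothetical names, not part of this module's API:

```elixir
defmodule FourierBasisSketch do
  # Hypothetical helper, not part of the documented API.
  # Expands a scalar x into 2 * grid_size features:
  # [cos(1x), sin(1x), cos(2x), sin(2x), ...].
  def features(x, grid_size) when is_number(x) and grid_size > 0 do
    Enum.flat_map(1..grid_size, fn k ->
      [:math.cos(k * x), :math.sin(k * x)]
    end)
  end

  # Weighted sum of the features; coeffs holds one {a, b} pair per
  # frequency (assumed to be the integers 1, 2, 3, ...).
  def eval(x, coeffs) do
    coeffs
    |> Enum.with_index(1)
    |> Enum.reduce(0.0, fn {{a, b}, k}, acc ->
      acc + a * :math.cos(k * x) + b * :math.sin(k * x)
    end)
  end
end
```

With `grid_size: 8`, `FourierBasisSketch.features/2` returns 16 values per input, which is where the factor of 2 in the Params column comes from.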
Architecture
    Input [batch, seq_len, embed_dim]
          |
          v
    +-------------------------------------+
    | KAN Block                           |
    | LayerNorm -> KAN Layer -> Residual  |
    | LayerNorm -> KAN Layer -> Residual  |
    +-------------------------------------+
          |  (repeat for num_layers)
          v
    Output [batch, hidden_size]

Usage
    # Build KAN backbone
    model = KAN.build(
      embed_dim: 287,
      hidden_size: 256,
      num_layers: 4,
      grid_size: 8,
      basis: :sine
    )

Comparison with MLP
| Aspect | MLP | KAN |
|---|---|---|
| Activation | Fixed on nodes | Learnable on edges |
| Interpretability | Low | High (visualizable) |
| Parameters | O(n^2) | O(n^2*g), where g = grid size |
| Best for | General tasks | Symbolic/scientific |
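To make the Parameters row concrete, the following sketch compares rough weight counts for a single square layer of width n. It ignores biases, normalization, and base-activation weights, so it only illustrates the scaling, not the count the library reports; `ParamScaleSketch` is a hypothetical name:

```elixir
defmodule ParamScaleSketch do
  # Rough weight counts for one square layer of width n, ignoring biases,
  # normalization, and base-activation weights (simplifying assumptions).
  def mlp_weights(n), do: n * n
  def kan_weights(n, g), do: n * n * g
end

n = 256
g = 8
IO.puts("MLP: #{ParamScaleSketch.mlp_weights(n)}")  # prints "MLP: 65536"
IO.puts("KAN: #{ParamScaleSketch.kan_weights(n, g)}")  # prints "KAN: 524288"
```

At the default grid size of 8, a KAN layer of this shape carries about eight times the weights of the equivalent dense layer.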
Summary
Functions
Build a KAN model for sequence processing.
Build a single KAN block.
Build a KAN layer with learnable edge activations.
Compute Chebyshev polynomial basis functions.
Default basis function type
Default dropout rate
Default grid size (number of basis functions)
Default hidden dimension
Default number of layers
Epsilon for numerical stability
Get the output size of a KAN model.
Calculate approximate parameter count for a KAN model.
Compute RBF (Radial Basis Function) basis.
Get recommended defaults for sequence processing.
Compute sine basis functions.
Types
@type build_opt() :: {:dropout, float()} | {:embed_dim, pos_integer()} | {:hidden_size, pos_integer()} | {:num_layers, pos_integer()} | {:seq_len, pos_integer()} | {:window_size, pos_integer()}
Options for build/1.
Functions
Build a KAN model for sequence processing.
Options
- :embed_dim - Size of input embedding per frame (required)
- :hidden_size - Internal hidden dimension (default: 256)
- :num_layers - Number of KAN blocks (default: 4)
- :grid_size - Number of basis functions per edge (default: 8)
- :basis - Basis function type, one of :bspline, :sine, :chebyshev, :fourier, :rbf (default: :bspline)
- :dropout - Dropout rate (default: 0.0)
- :window_size - Expected sequence length (default: 60)
- :base_weight - Weight for base SiLU activation (default: 0.5)
Returns
An Axon model that processes sequences and outputs the last hidden state.
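One way the documented defaults could be resolved against caller-supplied options is sketched below. This is an illustration in plain Elixir, not the module's actual code; `BuildOptsSketch.resolve/1` is a hypothetical helper, and the default values are copied from the option list above:

```elixir
defmodule BuildOptsSketch do
  # Defaults copied from the documented option list.
  @defaults [
    hidden_size: 256,
    num_layers: 4,
    grid_size: 8,
    basis: :bspline,
    dropout: 0.0,
    window_size: 60,
    base_weight: 0.5
  ]

  # Hypothetical helper: merges caller options over the defaults and
  # raises KeyError if the required :embed_dim is missing.
  def resolve(opts) when is_list(opts) do
    _ = Keyword.fetch!(opts, :embed_dim)
    Keyword.merge(@defaults, opts)
  end
end

opts = BuildOptsSketch.resolve(embed_dim: 287, basis: :sine)
IO.inspect(Keyword.take(opts, [:embed_dim, :basis, :hidden_size]))
```

Failing fast on the missing required key keeps shape errors from surfacing later, when the model is compiled.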
Build a single KAN block.
KAN block structure:
- LayerNorm -> KAN Layer -> Residual
- LayerNorm -> KAN Layer (wider) -> Residual
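The pre-norm residual pattern listed above can be sketched on a single feature vector in plain Elixir. The epsilon value, the absence of learnable scale and shift, and the requirement that each layer function preserves the vector length are all simplifying assumptions; `KanBlockSketch` is a hypothetical name, and the real block operates on Axon tensors:

```elixir
defmodule KanBlockSketch do
  # Assumed epsilon; the value returned by the library's eps/0 is not
  # stated in the docs.
  @eps 1.0e-5

  # Layer normalization over one feature vector, without learnable
  # scale and shift (a simplification).
  def layer_norm(xs) do
    n = length(xs)
    mean = Enum.sum(xs) / n
    var = Enum.sum(Enum.map(xs, fn x -> (x - mean) * (x - mean) end)) / n
    Enum.map(xs, fn x -> (x - mean) / :math.sqrt(var + @eps) end)
  end

  # One sub-block: normalize, apply `layer_fun`, add the input back.
  # `layer_fun` must return a list of the same length as its input.
  def residual(xs, layer_fun) do
    ys = xs |> layer_norm() |> layer_fun.()
    Enum.zip_with(xs, ys, fn x, y -> x + y end)
  end

  # A block chains two such sub-blocks.
  def block(xs, layer_fun1, layer_fun2) do
    xs |> residual(layer_fun1) |> residual(layer_fun2)
  end
end
```

Because the input is added back after each layer, a layer function that outputs zeros leaves the vector unchanged, which is the property that eases gradient flow through deep stacks.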
@spec build_kan_layer(Axon.t(), pos_integer(), keyword()) :: Axon.t()
Build a KAN layer with learnable edge activations.
KAN layer computes:
    y_i = Sum_j Phi_ij(x_j)

Where each Phi_ij is approximated as:

    Phi(x) = w_base * SiLU(x) + w_spline * Sum sin(omega*x)

This implementation uses a combination of:
- Base activation: SiLU(x) for gradient flow
- Learnable activation: Multi-frequency sine basis projected through dense layers
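The two components can be combined for scalar inputs as in the sketch below. It assumes fixed integer frequencies and a default `w_spline` of 0.5 (only the 0.5 default for the base weight comes from the documented options); the real layer projects the sine features through dense layers rather than summing per edge. `EdgeActivationSketch` is a hypothetical name:

```elixir
defmodule EdgeActivationSketch do
  # SiLU(x) = x * sigmoid(x)
  def silu(x), do: x / (1.0 + :math.exp(-x))

  # Learnable part: weighted sum of sines at fixed integer frequencies
  # (an assumption). `coeffs` holds one weight per frequency
  # 1..length(coeffs).
  def sine_sum(x, coeffs) do
    coeffs
    |> Enum.with_index(1)
    |> Enum.reduce(0.0, fn {c, k}, acc -> acc + c * :math.sin(k * x) end)
  end

  # Phi(x) = w_base * SiLU(x) + w_spline * Sum_k c_k * sin(k * x)
  # The w_spline default is assumed, not documented.
  def phi(x, coeffs, w_base \\ 0.5, w_spline \\ 0.5) do
    w_base * silu(x) + w_spline * sine_sum(x, coeffs)
  end

  # Node output y_i = Sum_j Phi_ij(x_j): one coefficient list per input edge.
  def node_output(xs, coeffs_per_edge) do
    xs
    |> Enum.zip(coeffs_per_edge)
    |> Enum.reduce(0.0, fn {x, coeffs}, acc -> acc + phi(x, coeffs) end)
  end
end
```

For example, `EdgeActivationSketch.node_output([0.2, -0.4], [[1.0, 0.5], [0.3, 0.1]])` sums two independently parameterized edge activations into a single node value.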
Compute Chebyshev polynomial basis functions.
ChebyKAN: y = Sum c_n * T_n(x), where T_0(x) = 1, T_1(x) = x, and T_{n+1}(x) = 2*x*T_n(x) - T_{n-1}(x)
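The three-term recurrence translates directly into a small reduction. The sketch below works on a scalar input in plain Elixir; it assumes x already lies in [-1, 1], where the polynomials stay bounded (how the library maps inputs into that range is not stated here). `ChebyshevSketch` is a hypothetical name:

```elixir
defmodule ChebyshevSketch do
  # Returns [T_0(x), ..., T_degree(x)] via the three-term recurrence
  # T_{n+1}(x) = 2*x*T_n(x) - T_{n-1}(x).
  def basis(_x, 0), do: [1.0]

  def basis(x, degree) when is_integer(degree) and degree >= 1 do
    {_, _, acc} =
      Enum.reduce(2..degree//1, {1.0, x, [x, 1.0]}, fn _n, {t_prev, t_cur, acc} ->
        t_next = 2 * x * t_cur - t_prev
        {t_cur, t_next, [t_next | acc]}
      end)

    Enum.reverse(acc)
  end
end

IO.inspect(ChebyshevSketch.basis(0.5, 3))  # prints [1.0, 0.5, -0.5, -1.0]
```

The last value matches the closed form T_3(x) = 4x^3 - 3x at x = 0.5.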
@spec default_basis() :: atom()
Default basis function type
@spec default_dropout() :: float()
Default dropout rate
@spec default_grid_size() :: pos_integer()
Default grid size (number of basis functions)
@spec default_num_layers() :: pos_integer()
Default number of layers
@spec eps() :: float()
Epsilon for numerical stability
@spec output_size(keyword()) :: non_neg_integer()
Get the output size of a KAN model.
@spec param_count(keyword()) :: non_neg_integer()
Calculate approximate parameter count for a KAN model.
Compute RBF (Radial Basis Function) basis.
y = Sum w * exp(-||x - mu||^2 / (2*sigma^2))
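For a scalar input, the formula reduces to a weighted sum of Gaussian bumps. The sketch below assumes evenly spaced centers and a single shared sigma, neither of which is confirmed by these docs; `RbfSketch` is a hypothetical name:

```elixir
defmodule RbfSketch do
  # Evenly spaced centers on [lo, hi] (an assumption); needs grid_size >= 2.
  def centers(lo, hi, grid_size) when grid_size >= 2 do
    step = (hi - lo) / (grid_size - 1)
    Enum.map(0..(grid_size - 1), fn k -> lo + k * step end)
  end

  # Gaussian bump per center: exp(-(x - mu)^2 / (2 * sigma^2)).
  def basis(x, mus, sigma) when sigma > 0 do
    Enum.map(mus, fn mu ->
      d = x - mu
      :math.exp(-(d * d) / (2 * sigma * sigma))
    end)
  end

  # y = Sum_k w_k * basis_k(x)
  def eval(x, mus, sigma, weights) do
    x
    |> basis(mus, sigma)
    |> Enum.zip(weights)
    |> Enum.reduce(0.0, fn {b, w}, acc -> acc + w * b end)
  end
end
```

Each basis value peaks at 1.0 when x sits exactly on a center and decays with distance, so sigma controls how much neighboring bumps overlap.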
@spec recommended_defaults() :: keyword()
Get recommended defaults for sequence processing.
Compute sine basis functions.
SineKAN: y = Sum A * sin(omega*x + phi)
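A scalar version of this sum is sketched below, with each term described by an amplitude, frequency, and phase. Treating all three as free per-term parameters is an assumption; the library may fix some of them. `SineBasisSketch` is a hypothetical name:

```elixir
defmodule SineBasisSketch do
  # One sine feature per {amplitude, omega, phase} triple.
  def basis(x, params) do
    Enum.map(params, fn {a, omega, phi} -> a * :math.sin(omega * x + phi) end)
  end

  # y = Sum A * sin(omega * x + phi)
  def eval(x, params), do: Enum.sum(basis(x, params))
end

params = [{2.0, 1.0, :math.pi() / 2}, {0.5, 3.0, 0.0}]
IO.inspect(SineBasisSketch.eval(0.0, params))
```

At x = 0 only the phase-shifted term contributes, since sin(0) = 0 for the second term.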