KolmogorovArnoldNetwork

class KolmogorovArnoldNetwork(layers, spline_order=3, n_knots=10, grid_range=[-1, 1], base_function=<class 'torch.nn.modules.activation.SiLU'>, use_base_linear=True, use_bias=True, init_scale_spline=0.01, init_scale_base=1.0)

Bases: Module

Implementation of Kolmogorov-Arnold Network (KAN).

The model consists of a sequence of KAN blocks, where each block applies a spline transformation to the input, optionally combined with a linear transformation of a base activation function.
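To make the description above concrete, here is a minimal sketch of what one such block could look like in plain PyTorch. It is not the library's implementation: the names `bspline_basis` and `KANBlockSketch` are hypothetical, and the uniform knot placement and the way `n_knots` determines the number of basis functions are assumptions made for illustration.

```python
import torch
from torch import nn


def bspline_basis(x, knots, order):
    """Evaluate B-spline basis functions of degree `order` at x.

    x: (batch, in_features); knots: sorted 1-D tensor of distinct values.
    Returns a tensor of shape (batch, in_features, len(knots) - order - 1).
    """
    x = x.unsqueeze(-1)
    # Degree 0: indicator function of each half-open knot interval.
    basis = ((x >= knots[:-1]) & (x < knots[1:])).to(x.dtype)
    # Cox-de Boor recursion raises the degree one step at a time.
    for k in range(1, order + 1):
        left = (x - knots[: -(k + 1)]) / (knots[k:-1] - knots[: -(k + 1)])
        right = (knots[k + 1 :] - x) / (knots[k + 1 :] - knots[1:-k])
        basis = left * basis[..., :-1] + right * basis[..., 1:]
    return basis


class KANBlockSketch(nn.Module):
    """Hypothetical block: spline term plus linear map of a base activation."""

    def __init__(self, in_features, out_features, spline_order=3,
                 n_knots=10, grid_range=(-1.0, 1.0), init_scale_spline=0.01):
        super().__init__()
        self.spline_order = spline_order
        # Assumption: knots are spaced uniformly over grid_range.
        self.register_buffer(
            "knots", torch.linspace(grid_range[0], grid_range[1], n_knots)
        )
        n_basis = n_knots - spline_order - 1
        # One set of control points per (output, input) pair.
        self.control_points = nn.Parameter(
            init_scale_spline * torch.randn(out_features, in_features, n_basis)
        )
        self.base_function = nn.SiLU()
        self.base_linear = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x):
        basis = bspline_basis(x, self.knots, self.spline_order)
        # Sum the learned univariate splines over the input features.
        spline_out = torch.einsum("bik,oik->bo", basis, self.control_points)
        return spline_out + self.base_linear(self.base_function(x))


torch.manual_seed(0)
block = KANBlockSketch(in_features=2, out_features=5)
x = torch.rand(4, 2) * 2 - 1  # samples inside the grid range [-1, 1]
y = block(x)
print(y.shape)  # torch.Size([4, 5])
```

In this sketch, inputs outside the knot range get a zero spline contribution, so only the base term remains for them. The actual class may treat the grid boundaries differently.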

See also

Original reference: Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., Hou T., Tegmark M. (2025). KAN: Kolmogorov-Arnold Networks. arXiv preprint arXiv:2404.19756.

Initialization of the KolmogorovArnoldNetwork class.

Parameters:
  • layers (list | tuple) – A list of integers specifying the size of each layer, including the input and output dimensions.

  • spline_order (int) – The order of each spline basis function. Default is 3 (cubic splines).

  • n_knots (int) – The number of knots for each spline basis function. Default is 10.

  • grid_range (list | tuple) – The range for the spline knots. It must be either a list or a tuple of the form [min, max]. Default is [-1, 1].

  • base_function (torch.nn.Module) – The base activation function to be applied to the input before the linear transformation. Default is torch.nn.SiLU.

  • use_base_linear (bool) – Whether to include a linear transformation of the base function output. Default is True.

  • use_bias (bool) – Whether to include a bias term in the output. Default is True.

  • init_scale_spline (float | int) – The scale for initializing the spline control points. Default is 1e-2.

  • init_scale_base (float | int) – The scale for initializing the base linear weights. Default is 1.0.
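As an illustration of the `layers` argument, the snippet below pairs consecutive entries to get the input and output size of each block. That consecutive entries map one-to-one onto KAN blocks is an assumption based on the description above, not something confirmed from the source code.

```python
# Example layer specification: 2 inputs, two hidden layers of width 5, 1 output.
layers = [2, 5, 5, 1]

# Assumed mapping: each consecutive pair gives one block's (input, output) size.
block_shapes = list(zip(layers[:-1], layers[1:]))
print(block_shapes)  # [(2, 5), (5, 5), (5, 1)]
```

Under this reading, a list of length n yields n - 1 blocks.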

Raises:

ValueError – If grid_range is not of length 2.
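The documented check can be pictured with the following sketch. The helper `check_grid_range` is hypothetical and the error message is invented for illustration; only the behavior of raising ValueError for a length other than 2 comes from the description above.

```python
def check_grid_range(grid_range):
    """Hypothetical helper mirroring the documented length check."""
    if len(grid_range) != 2:
        raise ValueError(
            f"grid_range must have length 2, got {len(grid_range)}."
        )
    return tuple(grid_range)


valid = check_grid_range([-1, 1])  # accepted: exactly [min, max]

try:
    check_grid_range([-1, 0, 1])  # rejected: three entries
except ValueError as err:
    message = str(err)
    print(message)  # grid_range must have length 2, got 3.
```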

forward(x)

Forward pass of the KolmogorovArnoldNetwork model. It passes the input through each KAN block in the network and returns the final output.

Parameters:

x (torch.Tensor | LabelTensor) – The input tensor for the model.

Returns:

The output tensor of the model.

Return type:

torch.Tensor | LabelTensor
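The forward pass described above amounts to applying the blocks one after another. A minimal sketch of that pattern follows. The class `SequentialBlocksSketch` is hypothetical, and `nn.Linear` is used purely as a placeholder for a KAN block so that the example runs with PyTorch alone.

```python
import torch
from torch import nn


class SequentialBlocksSketch(nn.Module):
    """Hypothetical stand-in showing only the block-by-block forward pass."""

    def __init__(self, layers):
        super().__init__()
        # nn.Linear is a placeholder here; the real model uses KAN blocks.
        self.blocks = nn.ModuleList(
            [nn.Linear(i, o) for i, o in zip(layers[:-1], layers[1:])]
        )

    def forward(self, x):
        # Feed the output of each block into the next one.
        for block in self.blocks:
            x = block(x)
        return x


model = SequentialBlocksSketch([2, 5, 1])
out = model(torch.rand(8, 2))
print(out.shape)  # torch.Size([8, 1])
```

The batch dimension is preserved, and the last dimension changes from the first entry of `layers` to the last.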