KolmogorovArnoldNetwork
- class KolmogorovArnoldNetwork(layers, spline_order=3, n_knots=10, grid_range=[-1, 1], base_function=<class 'torch.nn.modules.activation.SiLU'>, use_base_linear=True, use_bias=True, init_scale_spline=0.01, init_scale_base=1.0)
Bases:
Module
Implementation of the Kolmogorov-Arnold Network (KAN).
The model consists of a sequence of KAN blocks, where each block applies a spline transformation to the input, optionally combined with a linear transformation of a base activation function.
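The block structure described above can be illustrated with a minimal pure-Python sketch. It is not the library's implementation: the helper names (`silu`, `bspline_basis`, `kan_block`, `kan_forward`), the dictionary layout of a block, the uniform knot vector, and the way the spline and base terms are summed are all assumptions made for illustration, and `degree` is assumed to correspond to `spline_order`.

```python
import math

def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def bspline_basis(x, knots, degree):
    """Evaluate all B-spline basis functions of `degree` at scalar x (Cox-de Boor)."""
    # Degree 0: indicator of each half-open knot interval.
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    for k in range(1, degree + 1):
        nxt = []
        for i in range(len(knots) - k - 1):
            left_den = knots[i + k] - knots[i]
            right_den = knots[i + k + 1] - knots[i + 1]
            left = (x - knots[i]) / left_den * basis[i] if left_den > 0 else 0.0
            right = ((knots[i + k + 1] - x) / right_den * basis[i + 1]
                     if right_den > 0 else 0.0)
            nxt.append(left + right)
        basis = nxt
    return basis  # len(knots) - degree - 1 values

def kan_block(x, block, knots, degree):
    """Hypothetical block: out[j] = bias[j] + sum_i (base_w[j][i] * silu(x[i]) + spline_ji(x[i]))."""
    out = []
    for j in range(len(block["bias"])):
        total = block["bias"][j]
        for i, xi in enumerate(x):
            values = bspline_basis(xi, knots, degree)
            spline = sum(c * v for c, v in zip(block["coeffs"][j][i], values))
            total += block["base_w"][j][i] * silu(xi) + spline
        out.append(total)
    return out

def kan_forward(x, blocks, knots, degree):
    """Pass the input through each block in sequence."""
    for block in blocks:
        x = kan_block(x, block, knots, degree)
    return x

# Assumed setup: 10 uniform knots on [-1, 1], cubic splines, hence 6 basis functions.
knots = [-1.0 + 2.0 * i / 9 for i in range(10)]
degree = 3
n_basis = len(knots) - degree - 1

# Block 1 maps 2 inputs to 1 output using only the spline term.
block1 = {"coeffs": [[[1.0] * n_basis, [1.0] * n_basis]],
          "base_w": [[0.0, 0.0]], "bias": [0.0]}
# Block 2 maps 1 input to 1 output using only the base term plus a bias.
block2 = {"coeffs": [[[0.0] * n_basis]], "base_w": [[1.0]], "bias": [0.5]}

hidden = kan_block([0.0, 0.1], block1, knots, degree)
output = kan_forward([0.0, 0.1], [block1, block2], knots, degree)
```

In this sketch the basis functions sum to one only on the interior of the knot vector, so inputs near the ends of the grid receive a smaller spline contribution; a real implementation may extend the grid to avoid this.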
See also
Original reference: Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., Hou T., Tegmark M. (2025). KAN: Kolmogorov-Arnold Networks. arXiv preprint arXiv:2404.19756.
Initialization of the KolmogorovArnoldNetwork class.
- Parameters:
layers (list | tuple) – A list of integers specifying the sizes of each layer, including input and output dimensions.
spline_order (int) – The order of each spline basis function. Default is 3 (cubic splines).
n_knots (int) – The number of knots for each spline basis function. Default is 10.
grid_range (list | tuple) – The range for the spline knots. It must be either a list or a tuple of the form [min, max]. Default is [-1, 1].
base_function (torch.nn.Module) – The base activation function to be applied to the input before the linear transformation. Default is torch.nn.SiLU.
use_base_linear (bool) – Whether to include a linear transformation of the base function output. Default is True.
use_bias (bool) – Whether to include a bias term in the output. Default is True.
init_scale_spline (float | int) – The scale for initializing the spline control points. Default is 1e-2.
init_scale_base (float | int) – The scale for initializing the base linear weights. Default is 1.0.
- Raises:
ValueError – If grid_range is not of length 2.
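As a rough illustration of how the constructor arguments fit together, the sketch below pairs consecutive entries of `layers` into block shapes and builds a knot vector from `n_knots` and `grid_range`, raising `ValueError` when `grid_range` does not have length 2. The helper names `block_shapes` and `uniform_knots` are hypothetical, and uniform knot spacing is an assumption; the class itself may place its knots differently.

```python
def block_shapes(layers):
    """Pair consecutive layer sizes into (input_dim, output_dim) tuples."""
    return list(zip(layers[:-1], layers[1:]))

def uniform_knots(n_knots, grid_range):
    """Return n_knots evenly spaced knots over [min, max] (assumed layout)."""
    if len(grid_range) != 2:
        raise ValueError("grid_range must have length 2, in the form [min, max].")
    lo, hi = grid_range
    step = (hi - lo) / (n_knots - 1)
    return [lo + i * step for i in range(n_knots)]

shapes = block_shapes([2, 5, 1])        # one tuple per block
grid = uniform_knots(10, [-1, 1])       # mirrors the documented defaults
```

Under these assumptions, `layers=[2, 5, 1]` describes two blocks, and the default `n_knots=10` with `spline_order=3` would give 10 - 3 - 1 = 6 basis functions per spline.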
- forward(x)
Forward pass of the KolmogorovArnoldNetwork model. It passes the input through each KAN block in the network and returns the final output.
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor for the model.
- Returns:
The output tensor of the model.
- Return type: