KANBlock
- class KANBlock(input_dimensions, output_dimensions, spline_order=3, n_knots=10, grid_range=[0, 1], base_function=<class 'torch.nn.modules.activation.SiLU'>, use_base_linear=True, use_bias=True, init_scale_spline=0.01, init_scale_base=1.0)
Bases: Module
The inner block of the Kolmogorov-Arnold Network (KAN).
The block applies a learnable spline transformation to each input feature, optionally combined with a linear transformation of the output of a base activation function applied to the input. The resulting contributions are aggregated across input dimensions to produce each output feature.
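Following the formulation of Liu et al., the computation described above can be summarized as below, where $d_{\text{in}}$ and $d_{\text{out}}$ are input_dimensions and output_dimensions, $b$ is base_function, $B_k$ are the spline basis functions and $c_{j,i,k}$ their control points. Writing the aggregation as a sum, and attaching the base term and the bias $\beta_j$ as shown, are assumptions about how this class combines its terms.

$$
y_j = \beta_j + \sum_{i=1}^{d_{\text{in}}} \phi_{j,i}(x_i), \qquad j = 1, \dots, d_{\text{out}},
$$

$$
\phi_{j,i}(x) = w^{b}_{j,i}\, b(x) + \sum_{k} c_{j,i,k}\, B_k(x).
$$

Under this reading, the $w^{b}_{j,i}\, b(x)$ term is present only when use_base_linear is True, and $\beta_j$ only when use_bias is True.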
See also
Original reference: Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M., Hou T., Tegmark M. (2025). KAN: Kolmogorov-Arnold Networks. arXiv preprint arXiv:2404.19756.
Initialization of the KANBlock class.
- Parameters:
input_dimensions (int) – The number of input features.
output_dimensions (int) – The number of output features.
spline_order (int) – The order of each spline basis function. Default is 3 (cubic splines).
n_knots (int) – The number of knots for each spline basis function. Default is 10.
grid_range (list | tuple) – The range of the spline knot grid, given as a list or tuple of the form [min, max]. Default is [0, 1].
base_function (torch.nn.Module) – The base activation function to be applied to the input before the linear transformation. Default is torch.nn.SiLU.
use_base_linear (bool) – Whether to include a linear transformation of the base function output. Default is True.
use_bias (bool) – Whether to include a bias term in the output. Default is True.
init_scale_spline (float | int) – The scale for initializing the control points of each spline. Default is 1e-2.
init_scale_base (float | int) – The scale for initializing the base linear weights. Default is 1.0.
- Raises:
ValueError – If grid_range is not of length 2.
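The roles of spline_order, n_knots and grid_range can be illustrated with a plain-Python B-spline evaluation based on the Cox-de Boor recursion. This is a sketch under stated assumptions: the knot layout (n_knots evenly spaced points over grid_range, padded with spline_order extra knots on each side) is assumed rather than taken from this class, the helpers uniform_knots and bspline_basis are hypothetical, and spline_order is treated as the polynomial degree, matching "3 (cubic splines)" above.

```python
def uniform_knots(n_knots, grid_range, spline_order):
    """Hypothetical helper: n_knots evenly spaced points over grid_range,
    padded with spline_order extra knots on each side (assumed layout)."""
    low, high = grid_range
    if n_knots < 2 or high <= low:
        raise ValueError("need n_knots >= 2 and grid_range[0] < grid_range[1]")
    step = (high - low) / (n_knots - 1)
    return [low + i * step for i in range(-spline_order, n_knots + spline_order)]


def bspline_basis(x, knots, degree):
    """Evaluate every B-spline basis function of the given degree at x
    using the Cox-de Boor recursion."""
    # Degree 0: indicator of the half-open knot span [t_i, t_{i+1}).
    basis = [1.0 if knots[i] <= x < knots[i + 1] else 0.0
             for i in range(len(knots) - 1)]
    # Raise the degree one step at a time; each step shortens the list by one.
    for p in range(1, degree + 1):
        basis = [
            (x - knots[i]) / (knots[i + p] - knots[i]) * basis[i]
            + (knots[i + p + 1] - x) / (knots[i + p + 1] - knots[i + 1]) * basis[i + 1]
            for i in range(len(knots) - p - 1)
        ]
    return basis


knots = uniform_knots(n_knots=10, grid_range=[0, 1], spline_order=3)
values = bspline_basis(0.3, knots, degree=3)
print(len(values))             # 12 basis functions: n_knots + spline_order - 1
print(round(sum(values), 12))  # 1.0 (partition of unity inside grid_range)
```

With this layout, each spline is a weighted sum of n_knots + spline_order - 1 basis functions, and at any point inside grid_range only spline_order + 1 of them are nonzero, which keeps evaluation local.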
- forward(x)
Forward pass of the Kolmogorov-Arnold block. The input is passed through the spline transformation, optionally combined with a linear transformation of the base function output, and then aggregated across input dimensions to produce the final output.
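A minimal single-sample sketch of this computation, written in plain Python rather than PyTorch so that it runs without dependencies. It assumes that aggregation across input dimensions is a sum, as in the KAN formulation of Liu et al., and for brevity it replaces the default cubic splines with degree-1 hat functions. The names kan_block_forward and hat_basis, and the nested-list layout of the coefficients, are hypothetical illustrations, not the class's internals.

```python
import math


def silu(x):
    """Default base activation (SiLU): x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def hat_basis(x, centers, width):
    """Degree-1 B-splines (hat functions) on a uniform grid; a simplified
    stand-in for the cubic splines used by default."""
    return [max(0.0, 1.0 - abs(x - c) / width) for c in centers]


def kan_block_forward(x, spline_coeffs, centers, width,
                      base_weights=None, bias=None):
    """Hypothetical single-sample forward pass.

    x             -- list of input_dimensions floats
    spline_coeffs -- spline_coeffs[j][i] holds the control points of edge i -> j
    base_weights  -- optional base_weights[j][i] (None mimics use_base_linear=False)
    bias          -- optional bias[j] (None mimics use_bias=False)
    """
    out = []
    for j in range(len(spline_coeffs)):
        total = bias[j] if bias is not None else 0.0
        for i, xi in enumerate(x):
            # Spline term: control points weighted by the basis values at x_i.
            basis = hat_basis(xi, centers, width)
            total += sum(c * b for c, b in zip(spline_coeffs[j][i], basis))
            # Optional base term: a linear weight on the base activation.
            if base_weights is not None:
                total += base_weights[j][i] * silu(xi)
        # The contributions of all input features are summed into output j.
        out.append(total)
    return out


centers = [0.0, 0.25, 0.5, 0.75, 1.0]
# Control points equal to the grid points make each edge reproduce phi(x) = x on [0, 1].
coeffs = [[centers, centers]]  # 1 output, 2 inputs
print(kan_block_forward([0.25, 0.5], coeffs, centers, width=0.25))  # [0.75]
```

The double loop makes the structure explicit: one learnable univariate function per (input, output) pair, summed over inputs. A tensorized version would evaluate the basis for the whole batch at once and contract it with the coefficients over the input and basis axes.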
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor for the model.
- Returns:
The output tensor of the model.
- Return type: