Source code for pina._src.model.kolmogorov_arnold_network
import torch
from pina._src.model.block.kan_block import KANBlock
from pina._src.core.utils import check_consistency
[docs]
class KolmogorovArnoldNetwork(torch.nn.Module):
"""
Implementation of Kolmogorov-Arnold Network (KAN).
The model consists of a sequence of KAN blocks, where each block applies a
spline transformation to the input, optionally combined with a linear
transformation of a base activation function.
.. seealso::
**Original reference**:
Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
Hou T., Tegmark M. (2025).
*KAN: Kolmogorov-Arnold Networks*.
DOI: `arXiv preprint arXiv:2404.19756.
<https://arxiv.org/abs/2404.19756>`_
"""
def __init__(
self,
layers,
spline_order=3,
n_knots=10,
grid_range=[-1, 1],
base_function=torch.nn.SiLU,
use_base_linear=True,
use_bias=True,
init_scale_spline=1e-2,
init_scale_base=1.0,
):
"""
Initialization of the :class:`KolmogorovArnoldNetwork` class.
:param layers: A list of integers specifying the sizes of each layer,
including input and output dimensions.
:type layers: list | tuple.
:param int spline_order: The order of each spline basis function.
Default is 3 (cubic splines).
:param int n_knots: The number of knots for each spline basis function.
Default is 3.
:param grid_range: The range for the spline knots. It must be either a
list or a tuple of the form [min, max]. Default is [0, 1].
:type grid_range: list | tuple.
:param torch.nn.Module base_function: The base activation function to be
applied to the input before the linear transformation. Default is
:class:`torch.nn.SiLU`.
:param bool use_base_linear: Whether to include a linear transformation
of the base function output. Default is True.
:param bool use_bias: Whether to include a bias term in the output.
Default is True.
:param init_scale_spline: The scale for initializing each spline
control points. Default is 1e-2.
:type init_scale_spline: float | int.
:param init_scale_base: The scale for initializing the base linear
weights. Default is 1.0.
:type init_scale_base: float | int.
:raises ValueError: If ``grid_range`` is not of length 2.
"""
super().__init__()
# Check consistency -- all other checks are performed in KANBlock
check_consistency(layers, int)
if len(layers) < 2:
raise ValueError(
"`Provide at least two elements for layers (input and output)."
)
# Initialize KAN blocks
self.kan_layers = torch.nn.ModuleList(
[
KANBlock(
input_dimensions=layers[i],
output_dimensions=layers[i + 1],
spline_order=spline_order,
n_knots=n_knots,
grid_range=grid_range,
base_function=base_function,
use_base_linear=use_base_linear,
use_bias=use_bias,
init_scale_spline=init_scale_spline,
init_scale_base=init_scale_base,
)
for i in range(len(layers) - 1)
]
)
[docs]
def forward(self, x):
"""
Forward pass of the KolmogorovArnoldNetwork model. It passes the input
through each KAN block in the network and returns the final output.
:param x: The input tensor for the model.
:type x: torch.Tensor | LabelTensor
:return: The output tensor of the model.
:rtype: torch.Tensor | LabelTensor
"""
for layer in self.kan_layers:
x = layer(x)
return x