"""Module for the Kolmogorov-Arnold Network block."""
import torch
from pina._src.model.vectorized_spline import VectorizedSpline
from pina._src.core.utils import check_consistency, check_positive_integer
[docs]
class KANBlock(torch.nn.Module):
"""
The inner block of the Kolmogorov-Arnold Network (KAN).
The block applies a spline transformation to the input, optionally combined
with a linear transformation of a base activation function. The output is
aggregated across input dimensions to produce the final output.
.. seealso::
**Original reference**:
Liu Z., Wang Y., Vaidya S., Ruehle F., Halverson J., Soljacic M.,
Hou T., Tegmark M. (2025).
*KAN: Kolmogorov-Arnold Networks*.
DOI: `arXiv preprint arXiv:2404.19756.
<https://arxiv.org/abs/2404.19756>`_
"""
def __init__(
self,
input_dimensions,
output_dimensions,
spline_order=3,
n_knots=10,
grid_range=[0, 1],
base_function=torch.nn.SiLU,
use_base_linear=True,
use_bias=True,
init_scale_spline=1e-2,
init_scale_base=1.0,
):
"""
Initialization of the :class:`KANBlock` class.
:param int input_dimensions: The number of input features.
:param int output_dimensions: The number of output features.
:param int spline_order: The order of each spline basis function.
Default is 3 (cubic splines).
:param int n_knots: The number of knots for each spline basis function.
Default is 10.
:param grid_range: The range for the spline knots. It must be either a
list or a tuple of the form [min, max]. Default is [0, 1].
:type grid_range: list | tuple.
:param torch.nn.Module base_function: The base activation function to be
applied to the input before the linear transformation. Default is
:class:`torch.nn.SiLU`.
:param bool use_base_linear: Whether to include a linear transformation
of the base function output. Default is True.
:param bool use_bias: Whether to include a bias term in the output.
Default is True.
:param init_scale_spline: The scale for initializing each spline
control points. Default is 1e-2.
:type init_scale_spline: float | int.
:param init_scale_base: The scale for initializing the base linear
weights. Default is 1.0.
:type init_scale_base: float | int.
:raises ValueError: If ``grid_range`` is not of length 2.
"""
super().__init__()
# Check consistency
check_consistency(base_function, torch.nn.Module, subclass=True)
check_positive_integer(input_dimensions, strict=True)
check_positive_integer(output_dimensions, strict=True)
check_positive_integer(spline_order, strict=True)
check_positive_integer(n_knots, strict=True)
check_consistency(use_base_linear, bool)
check_consistency(use_bias, bool)
check_consistency(init_scale_spline, (int, float))
check_consistency(init_scale_base, (int, float))
check_consistency(grid_range, (int, float))
# Raise error if grid_range is not valid
if len(grid_range) != 2:
raise ValueError("Grid must be a list or tuple with two elements.")
# Knots for the spline basis functions
initial_knots = torch.ones(spline_order) * grid_range[0]
final_knots = torch.ones(spline_order) * grid_range[1]
# Number of internal knots
n_internal = max(0, n_knots - 2 * spline_order)
# Internal knots are uniformly spaced in the grid range
internal_knots = torch.linspace(
grid_range[0], grid_range[1], n_internal + 2
)[1:-1]
# Define the knots
knots = torch.cat((initial_knots, internal_knots, final_knots))
knots = knots.unsqueeze(0).repeat(input_dimensions, 1)
# Define the control points for the spline basis functions
control_points = (
torch.randn(
input_dimensions,
output_dimensions,
knots.shape[-1] - spline_order,
)
* init_scale_spline
)
# Define the vectorized spline module
self.spline = VectorizedSpline(
order=spline_order, knots=knots, control_points=control_points
)
# Initialize the base function
self.base_function = base_function()
# Initialize the base linear weights if needed
if use_base_linear:
self.base_weight = torch.nn.Parameter(
torch.randn(output_dimensions, input_dimensions)
* (init_scale_base / (input_dimensions**0.5))
)
else:
self.register_parameter("base_weight", None)
# Initialize the bias term if needed
if use_bias:
self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
else:
self.register_parameter("bias", None)
[docs]
def forward(self, x):
"""
Forward pass of the Kolmogorov-Arnold block. The input is passed through
the spline transformation, optionally combined with a linear
transformation of the base function output, and then aggregated across
input dimensions to produce the final output.
:param x: The input tensor for the model.
:type x: torch.Tensor | LabelTensor
:return: The output tensor of the model.
:rtype: torch.Tensor | LabelTensor
"""
y = self.spline(x)
if self.base_weight is not None:
base_x = self.base_function(x)
base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
y = y + base_out
# aggregate contributions from all input dimensions
y = y.sum(dim=1)
if self.bias is not None:
y = y + self.bias
return y