Spline

class Spline(order=4, knots=None, control_points=None)

Bases: torch.nn.Module

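Spline model: a B-spline of the given order, defined by a knot vector and a set of control points and implemented as a PyTorch module.

Initialize the spline from its order, knots, and control points.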

Parameters:
  • order (int) – The order of the spline (its polynomial degree is order - 1, so the default of 4 gives a cubic spline). Default is 4.

  • knots (torch.Tensor) – The one-dimensional knot vector. If None, the knots are initialized automatically. Default is None.

  • control_points (torch.Tensor) – The control points (spline coefficients). Default is None. At least one of knots and control_points must be provided.

Raises:
  • ValueError – If the order is negative.

  • ValueError – If both knots and control points are None.

  • ValueError – If the knot tensor is not one-dimensional.
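Under the standard B-spline convention (assumed here, the section does not state it), n control points of a given order need n + order knots. The section also does not say how knots are initialized automatically, so the sketch below shows one common choice purely as an illustration: a clamped, uniformly spaced knot vector on [0, 1]. The helper clamped_uniform_knots is hypothetical, uses plain Python lists instead of torch.Tensor, and is not necessarily what Spline does internally.

```python
def clamped_uniform_knots(n_control_points, order, t_min=0.0, t_max=1.0):
    """Build a clamped, uniformly spaced knot vector (hypothetical helper).

    Assumes the standard convention that a B-spline of this order
    (degree = order - 1) with n control points needs n + order knots.
    Clamping repeats each end knot `order` times, so the curve starts at
    the first control point and ends at the last one.
    """
    if order < 1:
        raise ValueError("order must be at least 1")
    if n_control_points < order:
        raise ValueError("need at least `order` control points")
    # Number of knots strictly between the clamped ends.
    n_interior = n_control_points - order
    step = (t_max - t_min) / (n_interior + 1)
    interior = [t_min + step * j for j in range(1, n_interior + 1)]
    return [t_min] * order + interior + [t_max] * order
```

For example, clamped_uniform_knots(6, 4) gives 10 knots: four zeros, 1/3, 2/3, and four ones.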

basis(x, k, i, t)

Recursively compute the i-th basis function of degree k over the knot vector t.
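Assuming the method follows the standard Cox–de Boor recursion (the usual recursive definition of B-spline basis functions; the source does not name it), the basis function with index i and degree k on knots t is defined as below, where any term whose denominator vanishes because of repeated knots is taken to be zero:

```latex
B_{i,0}(x) =
\begin{cases}
  1 & \text{if } t_i \le x < t_{i+1}, \\
  0 & \text{otherwise,}
\end{cases}
\qquad
B_{i,k}(x) = \frac{x - t_i}{t_{i+k} - t_i}\, B_{i,k-1}(x)
           + \frac{t_{i+k+1} - x}{t_{i+k+1} - t_{i+1}}\, B_{i+1,k-1}(x).
```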

Parameters:
  • x (torch.Tensor) – The points to be evaluated.

  • k (int) – The degree of the basis function.

  • i (int) – The index of the basis function, i.e. of the knot interval [t[i], t[i+1]) where its support begins.

  • t (torch.Tensor) – The tensor of knots.

Returns:

The values of the i-th basis function of degree k at the points x.

Return type:

torch.Tensor
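A plain-Python sketch of this kind of recursion for a single scalar x can help when checking values by hand. It assumes the standard Cox–de Boor definition; the name bspline_basis is hypothetical, it takes a list of knots instead of a torch.Tensor, and closing the last non-empty interval at x = t[-1] is an assumption of this sketch rather than documented behavior of Spline.basis.

```python
def bspline_basis(x, k, i, t):
    """Value at scalar x of the i-th B-spline basis function of degree k.

    Hypothetical plain-Python Cox-de Boor recursion over the knot list t.
    Terms whose knot span has zero width (repeated knots) are treated as 0.
    """
    if k == 0:
        # Degree 0: indicator of the half-open interval [t[i], t[i+1]).
        if t[i] <= x < t[i + 1]:
            return 1.0
        # Assumption: close the last non-empty interval so x == t[-1] counts.
        if x == t[-1] and t[i] < t[i + 1] == t[-1]:
            return 1.0
        return 0.0
    left = 0.0
    if t[i + k] != t[i]:
        left = (x - t[i]) / (t[i + k] - t[i]) * bspline_basis(x, k - 1, i, t)
    right = 0.0
    if t[i + k + 1] != t[i + 1]:
        right = ((t[i + k + 1] - x) / (t[i + k + 1] - t[i + 1])
                 * bspline_basis(x, k - 1, i + 1, t))
    return left + right
```

On the clamped cubic knots [0, 0, 0, 0, 1, 1, 1, 1], the four basis functions at x = 0.5 are 0.125, 0.375, 0.375 and 0.125, and they sum to 1. Each call spawns two recursive calls, so the cost per point grows like 2^k, which is harmless for the small degrees typical of splines.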

property control_points

The control points of the spline.

Returns:

The control points.

Return type:

torch.Tensor

property knots

The knots of the spline.

Returns:

The knots.

Return type:

torch.Tensor

forward(x)

Forward pass of the Spline model: evaluate the spline at the points x.

Parameters:

x (torch.Tensor) – The points at which the spline is evaluated.

Returns:

The values of the spline at x.

Return type:

torch.Tensor
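Assuming forward computes the standard spline sum y(x) = Σ_i c_i B_{i,k}(x), with c the control points and k = order - 1, the value at a single point can be cross-checked with de Boor's algorithm. This is a swapped-in technique, a standard and numerically stable alternative to summing the recursive basis functions, not necessarily how Spline.forward is implemented. The function de_boor is hypothetical and works on plain Python lists.

```python
import bisect


def de_boor(x, t, c, order):
    """Evaluate sum_i c[i] * B_{i, order-1}(x) at a scalar x (hypothetical helper).

    t is a non-decreasing knot list with len(t) == len(c) + order.
    Intended for x in [t[order-1], t[len(c)]]; outside it, this extrapolates.
    """
    p = order - 1  # polynomial degree
    n = len(c)
    if len(t) != n + order:
        raise ValueError("expected len(t) == len(c) + order")
    # Span index s with t[s] <= x < t[s+1], kept inside the valid range.
    s = bisect.bisect_right(t, x) - 1
    s = min(max(s, p), n - 1)
    # Only the p + 1 control points c[s-p], ..., c[s] affect this span.
    d = [c[j + s - p] for j in range(p + 1)]
    for r in range(1, p + 1):
        for j in range(p, r - 1, -1):
            denom = t[j + 1 + s - r] - t[j + s - p]
            # Guard against repeated knots; never zero on a non-empty span.
            alpha = 0.0 if denom == 0 else (x - t[j + s - p]) / denom
            d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
    return d[p]
```

With control points placed at the Greville abscissae (t[i+1] + ... + t[i+k]) / k, a B-spline reproduces the straight line y = x, which makes a convenient sanity check for any implementation of the spline sum.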