Spline
- class Spline(order=4, knots=None, control_points=None)
Bases: Module
The univariate B-spline curve model class.
A univariate B-spline curve of order \(k\) is a parametric curve defined as a linear combination of B-spline basis functions and control points:
\[S(x) = \sum_{i=1}^{n} B_{i,k}(x) C_i, \quad x \in [x_1, x_m],\]
where:
- \(C \in \mathbb{R}^n\) is the vector of learnable control coefficients. Its entries \(C_i\) shape the curve but are generally not interpolated, except for particular knot multiplicities.
- \(B_{i,k}(x)\) are the B-spline basis functions of order \(k\), i.e., piecewise polynomials of degree \(k-1\), each supported on the interval \([x_i, x_{i+k}]\).
- \(X = \{ x_1, x_2, \dots, x_m \}\) is the non-decreasing knot vector, with \(m = n + k\) knots.
If the first and last knots are repeated \(k\) times, then the curve interpolates the first and last control coefficients.
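For reference, the basis functions in the sum above can be defined by the standard Cox-de Boor recursion, written here with the same order-\(k\) indexing (each \(B_{i,k}\) supported on \([x_i, x_{i+k}]\)); any fraction whose denominator vanishes because of a repeated knot is taken to be zero. This is the textbook definition, shown for orientation rather than as a transcription of PINA's code:

\[B_{i,1}(x) = \begin{cases} 1 & x_i \le x < x_{i+1}, \\ 0 & \text{otherwise}, \end{cases}\]
\[B_{i,k}(x) = \frac{x - x_i}{x_{i+k-1} - x_i} B_{i,k-1}(x) + \frac{x_{i+k} - x}{x_{i+k} - x_{i+1}} B_{i+1,k-1}(x).\]

For instance, with \(X = \{0, 0, 0, 1, 2, 2, 2\}\) and \(k = 3\) (so \(n = m - k = 4\)), the recursion gives \(B_{1,3}(0.5) = 0.25\), \(B_{2,3}(0.5) = 0.625\), \(B_{3,3}(0.5) = 0.125\) and \(B_{4,3}(0.5) = 0\), which sum to one. With \(C = (0, 1, 3, 2)\), this yields \(S(0.5) = 0.625 \cdot 1 + 0.125 \cdot 3 = 1\).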
Note
The curve is forced to be zero outside the interval defined by the first and last knots.
- Example:
>>> from pina.model import Spline
>>> import torch

>>> knots1 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
>>> spline1 = Spline(order=3, knots=knots1, control_points=None)

>>> knots2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto"}
>>> spline2 = Spline(order=3, knots=knots2, control_points=None)

>>> knots3 = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0])
>>> control_points3 = torch.tensor([0.0, 1.0, 3.0, 2.0])
>>> spline3 = Spline(order=3, knots=knots3, control_points=control_points3)
Initialization of the Spline class.
- Parameters:
order (int) – The order of the spline. The corresponding basis functions are polynomials of degree order - 1. Default is 4.
knots (torch.Tensor | dict) – The knots of the spline. If a tensor is provided, the knots are set directly from it. If a dictionary is provided, it must contain the keys "n", "min", "max", and "mode": "n" is the number of knots, "min" and "max" define the interval, and "mode" selects the sampling strategy. The supported modes are "uniform", where the knots are evenly spaced over \([min, max]\), and "auto", where the knots are constructed so that the spline interpolates the first and last control points; in "auto" mode, the number of knots is adjusted if \(n < 2 * order\). If None, the knots are initialized automatically over \([0, 1]\) so that the spline interpolates the first and last control points; in this case control_points must be provided. Default is None.
control_points (torch.Tensor) – The control points of the spline. If None, they are initialized as learnable parameters with an initial value of zero. Default is None.
- Raises:
AssertionError – If order is not a positive integer.
ValueError – If knots is provided but is neither a torch.Tensor nor a dictionary.
ValueError – If control_points is provided but is not a torch.Tensor.
ValueError – If both knots and control_points are None.
ValueError – If knots is not one-dimensional.
ValueError – If control_points is not one-dimensional.
ValueError – If the number of knots is not equal to the sum of order and the number of control_points.
UserWarning – If the number of control points is lower than the order, resulting in a degenerate spline.
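The knot modes and the counting rules listed above can be sketched on plain Python lists. Both helpers, make_knots and num_control_points, are hypothetical and are not part of PINA's API. Two behaviours are assumptions rather than documented facts: "auto" mode is modelled as a clamped vector (min and max each repeated order times, the remaining knots evenly spaced in between), and "adjusted" is read as raising n to 2 * order. The type and dimensionality checks are omitted.

```python
import warnings


def make_knots(n, lo, hi, mode, order):
    """Build a knot vector from a dict-style spec (hypothetical helper).

    "uniform": n evenly spaced knots over [lo, hi].
    "auto": assumed clamped vector, lo and hi each repeated `order` times,
    with the remaining knots evenly spaced strictly inside (lo, hi).
    """
    if mode == "uniform":
        if n < 2:
            raise ValueError("uniform mode needs at least two knots")
        return [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    if mode == "auto":
        # Assumption: "adjusted" means raised to the minimum 2 * order.
        n = max(n, 2 * order)
        n_inner = n - 2 * order
        inner = [lo + (hi - lo) * (j + 1) / (n_inner + 1) for j in range(n_inner)]
        return [lo] * order + inner + [hi] * order
    raise ValueError(f"unsupported mode: {mode!r}")


def num_control_points(order, knots=None, control_points=None):
    """Apply the documented counting rules to plain lists (hypothetical helper)."""
    assert isinstance(order, int) and order > 0, "order must be a positive integer"
    if knots is None and control_points is None:
        raise ValueError("knots and control_points cannot both be None")
    if knots is None:
        # Knots would be generated to match the given control points.
        n_control = len(control_points)
    else:
        # The number of knots must equal order + number of control points.
        n_control = len(knots) - order
        if control_points is not None and len(control_points) != n_control:
            raise ValueError("len(knots) must equal order + len(control_points)")
    if n_control < order:
        warnings.warn("fewer control points than the order: degenerate spline", UserWarning)
    return n_control
```

Under these assumptions, make_knots(7, 0.0, 2.0, "auto", 3) returns the same vector that the example passes as knots1, and that vector implies 4 control points, matching the length of control_points3.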
- basis(x, collection=False)
Compute the basis functions for the spline using an iterative approach. This is a vectorized implementation based on the Cox-de Boor recursion.
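The iterative, order-by-order structure of the Cox-de Boor recursion can be sketched in plain Python for a single scalar point. This is a hypothetical helper, not PINA's vectorized implementation: it works on lists instead of tensors, and closing the last non-empty knot span on the right (so that the basis does not vanish at \(x = x_m\)) is an assumption of the sketch. Passing collection=True returns the basis values for every order from 1 up to order, mirroring the collection parameter of basis.

```python
def basis(x, knots, order, collection=False):
    """Evaluate every B-spline basis function at a scalar x (hypothetical helper).

    Returns [B_1(x), ..., B_{m-order}(x)] for the requested order, or, when
    collection is True, one such list for each order 1, 2, ..., order.
    """
    if not isinstance(collection, bool):
        raise ValueError("collection must be a boolean")
    m = len(knots)
    # Order 1: indicator of the half-open span [knots[i], knots[i + 1]).
    current = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(m - 1)]
    # Assumption: close the last non-empty span on the right so that
    # x == knots[-1] is still covered.
    if x == knots[-1]:
        last = max(i for i in range(m - 1) if knots[i] < knots[i + 1])
        current[last] = 1.0
    levels = [current]
    # Raise the order one step at a time (Cox-de Boor), treating 0/0 as 0.
    for k in range(2, order + 1):
        nxt = []
        for i in range(m - k):
            left_den = knots[i + k - 1] - knots[i]
            right_den = knots[i + k] - knots[i + 1]
            left = (x - knots[i]) / left_den * current[i] if left_den > 0 else 0.0
            right = (knots[i + k] - x) / right_den * current[i + 1] if right_den > 0 else 0.0
            nxt.append(left + right)
        current = nxt
        levels.append(current)
    return levels if collection else current
```

With the clamped knots [0, 0, 0, 1, 2, 2, 2] and order 3, the sketch returns [1, 0, 0, 0] at x = 0 and [0, 0, 0, 1] at x = 2, which is why such a knot vector interpolates the first and last control coefficients.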
- Parameters:
x (torch.Tensor) – The points to be evaluated.
collection (bool) – If True, returns a list of basis functions for all orders up to the spline order. Default is False.
- Raises:
ValueError – If collection is not a boolean.
- Returns:
The basis functions evaluated at x.
- Return type:
- forward(x)
Forward pass for the Spline model.
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor.
- Returns:
The output tensor.
- Return type:
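The forward pass evaluates the sum \(S(x) = \sum_i B_{i,k}(x) C_i\) from the class description. As an alternative illustration, and not necessarily what forward does internally, de Boor's algorithm evaluates the same sum at a scalar point by repeated linear interpolation of the \(k\) control points that are active on the knot span containing \(x\), without forming the basis functions. The helper spline_value is hypothetical, works on plain lists, and assumes a clamped knot vector; it returns zero outside the first and last knots, following the note in the class description.

```python
def spline_value(x, knots, control_points, order):
    """Evaluate S(x) = sum_i B_{i,k}(x) C_i at a scalar x with de Boor's algorithm.

    Hypothetical helper on plain lists; assumes a clamped knot vector.
    """
    n = len(control_points)
    if len(knots) != n + order:
        raise ValueError("len(knots) must equal order + len(control_points)")
    if n < order:
        raise ValueError("this sketch needs at least `order` control points")
    if knots[:order] != [knots[0]] * order or knots[-order:] != [knots[-1]] * order:
        raise ValueError("this sketch assumes a clamped knot vector")
    # Zero outside [x_1, x_m], as stated in the class note.
    if x < knots[0] or x > knots[-1]:
        return 0.0
    # Locate the span s with knots[s] <= x < knots[s + 1]; for clamped knots,
    # x == knots[-1] falls into the last span s = n - 1.
    s = order - 1
    while s < n - 1 and x >= knots[s + 1]:
        s += 1
    # Only the `order` control points C[s - order + 1 .. s] affect this span.
    d = [control_points[s - order + 1 + j] for j in range(order)]
    for r in range(1, order):
        for j in range(order - 1, r - 1, -1):
            i = s - order + 1 + j
            den = knots[i + order - r] - knots[i]
            alpha = (x - knots[i]) / den if den > 0 else 0.0
            d[j] = (1.0 - alpha) * d[j - 1] + alpha * d[j]
    return d[order - 1]
```

With the knots and control points of spline3 in the class example, this sketch gives 0 at x = 0, 1 at x = 0.5 and 2 at x = 2, i.e., the first and last control coefficients are reproduced at the ends of the interval.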
- derivative(x, degree)
Compute the degree-th derivative of the spline at the given points.
- Parameters:
x (torch.Tensor | LabelTensor) – The input tensor.
degree (int) – The derivative degree to compute.
- Raises:
ValueError – If degree is not an integer.
- Returns:
The derivative tensor.
- Return type:
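As a point of reference (a standard B-spline identity, not necessarily how derivative is implemented), differentiating a spline of order \(k\) yields a spline of order \(k-1\) on the same knots, with differenced coefficients:

\[S'(x) = (k-1) \sum_{i=1}^{n+1} \frac{C_i - C_{i-1}}{x_{i+k-1} - x_i} B_{i,k-1}(x), \qquad C_0 = C_{n+1} = 0,\]

where any term with a zero denominator is dropped. Applying the identity degree times gives the degree-th derivative, a spline of order \(k - \text{degree}\), which is zero between knots once \(\text{degree} \ge k\). For example, with \(X = \{0, 0, 0, 1, 2, 2, 2\}\), \(k = 3\) and \(C = (0, 1, 3, 2)\), the curve reduces to \(S(x) = 2x\) on \([0, 1]\); the identity gives \(B_{2,2}(0.5) = B_{3,2}(0.5) = 0.5\) with coefficients \((1 - 0)/(1 - 0) = 1\) and \((3 - 1)/(2 - 0) = 1\), so \(S'(0.5) = 2\,(0.5 + 0.5) = 2\).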
- property control_points
The control points of the spline.
- Returns:
The control points.
- Return type:
- property knots
The knots of the spline.
- Returns:
The knots.
- Return type: