VectorizedSpline

class VectorizedSpline(order=4, knots=None, control_points=None, aggregate_output=None)

Bases: Module

The vectorized B-spline model class.

A VectorizedSpline represents a vector spline, i.e., a collection of independent univariate B-splines evaluated in parallel. Each univariate spline has its own knot vector and its own control points, and acts on one input feature.

Given s univariate splines, the vector spline maps an input \(x = (x^{(1)}, \dots, x^{(s)}) \in \mathbb{R}^s\) to an output obtained by evaluating each univariate spline on its corresponding scalar input \(x^{(j)}\).

For the \(j\)-th univariate spline of order \(k\), the output is defined as

\[S^{(j)}(x^{(j)}) = \sum_{i=1}^{n_j} B_{i,k}^{(j)}(x^{(j)}) C_i^{(j)},\]

where:

  • \(C^{(j)}\) are the control points of the \(j\)-th univariate spline. In the scalar-output case, \(C^{(j)} \in \mathbb{R}^{n_j}\). More generally, each univariate spline may have output dimension \(o\), so \(C^{(j)} \in \mathbb{R}^{o \times n_j}\).

  • \(B_{i,k}^{(j)}(x)\) are the B-spline basis functions of order \(k\), i.e., piecewise polynomials of degree \(k-1\), associated with the knot vector of the \(j\)-th univariate spline.

  • \(X^{(j)} = \{x_1^{(j)}, x_2^{(j)}, \dots, x_{m_j}^{(j)}\}\) is the non-decreasing knot vector of the \(j\)-th univariate spline.

If the first and last knots of a given univariate spline are repeated \(k\) times, then that univariate spline interpolates its first and last control points.
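
As an illustration of the definition above, the following minimal sketch evaluates one scalar-output univariate spline with the Cox-de Boor recursion in plain Python. It is not the library's implementation: the helper names `bspline_basis` and `spline_value`, the 0-based indexing, and the half-open span convention are assumptions made for this example.

```python
def bspline_basis(i, k, knots, x):
    """Order-k basis function B_{i,k}(x) via Cox-de Boor (0-based index i)."""
    if k == 1:
        # Order 1: indicator of the half-open knot span [t_i, t_{i+1}).
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left_den = knots[i + k - 1] - knots[i]
    right_den = knots[i + k] - knots[i + 1]
    # Terms with a zero denominator (repeated knots) are dropped by convention.
    left = 0.0
    if left_den > 0:
        left = (x - knots[i]) / left_den * bspline_basis(i, k - 1, knots, x)
    right = 0.0
    if right_den > 0:
        right = (knots[i + k] - x) / right_den * bspline_basis(i + 1, k - 1, knots, x)
    return left + right


def spline_value(x, order, knots, control_points):
    """S(x) = sum_i B_{i,k}(x) * C_i for one scalar-output univariate spline."""
    return sum(
        bspline_basis(i, order, knots, x) * c
        for i, c in enumerate(control_points)
    )


# Illustrative data: order k = 3, first and last knots repeated k times.
knots = [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]
ctrl = [0.5, 1.0, 3.0, 2.0]  # len(knots) - order = 4 control points

print(spline_value(0.0, 3, knots, ctrl))         # 0.5, the first control point
print(spline_value(2.0 - 1e-9, 3, knots, ctrl))  # close to 2.0, the last control point
print(spline_value(2.5, 3, knots, ctrl))         # zero outside the knot range
```

With the half-open spans used in this sketch, the value exactly at the last knot is zero, so the endpoint check is made just to its left. This page does not state how the library treats that boundary point.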

The full vector spline evaluates all univariate splines in parallel. If each univariate spline has output dimension \(o\), then before optional aggregation the output has shape [batch, s, o].
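
The feature-wise structure can be sketched without any spline machinery. In the snippet below, two arbitrary scalar-to-vector functions stand in for the univariate splines (they are hypothetical placeholders, not B-splines), and `evaluate_vector_spline` is an illustrative helper that maps a nested list of shape [batch, s] to one of shape [batch, s, o].

```python
import math

# Hypothetical stand-ins for s = 2 univariate splines with output dimension o = 3.
component_maps = [
    lambda t: [t, t ** 2, t ** 3],
    lambda t: [math.sin(t), math.cos(t), 1.0],
]


def evaluate_vector_spline(batch, maps):
    """Apply the j-th map to the j-th feature of every sample."""
    return [[f(x_j) for f, x_j in zip(maps, sample)] for sample in batch]


batch = [[0.0, 0.0], [0.5, 1.0], [1.0, 2.0], [2.0, 0.5]]  # shape [4, 2]
out = evaluate_vector_spline(batch, component_maps)
shape = (len(out), len(out[0]), len(out[0][0]))
print(shape)  # (4, 2, 3)

# Changing the second feature leaves the first component's output untouched.
perturbed = evaluate_vector_spline([[0.5, 7.0]], component_maps)
print(perturbed[0][0] == out[1][0])  # True
```

The second check is the practical meaning of "independent": output slot j depends only on input feature j.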

Note

Each univariate spline is forced to be zero outside the interval defined by the first and last knots of its own knot vector.

Note

This class does not represent a single multivariate spline \(\mathbb{R}^s \to \mathbb{R}^o\) with a genuinely multivariate basis. Instead, it represents a vector of splines built from s independent univariate splines, one for each input feature.

Note

When using the derivative() method of this class, derivatives are computed directly in vectorized form and returned with the correct shape. In contrast, when relying on autograd, derivatives must be computed separately for each output dimension of each univariate spline and then combined, since autograd does not natively handle this vectorized structure.

Example:

>>> from pina.model import VectorizedSpline
>>> import torch
>>> knt1 = torch.tensor([
...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
...     [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0],
... ])
>>> spline1 = VectorizedSpline(order=3, knots=knt1, control_points=None)
>>> knt2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto", "n_splines": 2}
>>> spline2 = VectorizedSpline(order=3, knots=knt2, control_points=None)
>>> knt3 = torch.tensor([
...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
...     [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
... ])
>>> ctrl3 = torch.tensor([
...     [0.0, 1.0, 3.0, 2.0],
...     [1.0, 0.0, 2.0, 1.0],
... ])
>>> spline3 = VectorizedSpline(order=3, knots=knt3, control_points=ctrl3)

Initialization of the VectorizedSpline class.

Parameters:
  • order (int) – The order of each univariate spline. The corresponding basis functions are polynomials of degree order - 1. Default is 4.

  • knots (torch.Tensor | dict) – The knots of the univariate splines. If a tensor is provided, it must have shape [s, n], where s is the number of univariate splines and n is the number of knots per univariate spline. If a dictionary is provided, it must contain the keys "n", "min", "max", "mode", and "n_splines": "n" is the number of knots for each univariate spline, "min" and "max" define the interval, "mode" selects the sampling strategy, and "n_splines" is the number of univariate splines. The supported modes are "uniform", where the knots are evenly spaced over \([\text{min}, \text{max}]\), and "auto", where the knots are constructed so that each univariate spline interpolates its first and last control points. In "auto" mode, the number of knots is adjusted if \(n < 2 \cdot \text{order}\). If None is given, knots are initialized automatically over \([0, 1]\) so that each univariate spline interpolates its first and last control points. Default is None.

  • control_points (torch.Tensor) – The control points tensor. The tensor must be either of shape [s, o, c] or [s, c], where each univariate spline has c control points and output dimension o. In the latter case, the control points are expanded to shape [s, 1, c]. If None, control points are initialized to learnable parameters with zero initial value. Default is None.

  • aggregate_output (str) – If None, the output of each univariate spline is returned separately, resulting in an output of shape [batch, s, o], where s is the number of univariate splines and o is the output dimension of each univariate spline. If set to "mean" or "sum", the output is aggregated accordingly across the last dimension, resulting in an output of shape [batch, s]. Default is None.
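
To make the two knot modes concrete, here is a minimal plain-Python sketch of how such knot vectors could be built for a single univariate spline. The helper `make_knots` is hypothetical and not part of the library. In particular, raising a too-small `n` to `2 * order` and spacing the interior knots evenly are assumptions, since this page only says that the number of knots "is adjusted".

```python
def make_knots(n, lo, hi, mode, order):
    """Build one knot vector (illustrative only, not the library's routine)."""
    if mode == "uniform":
        step = (hi - lo) / (n - 1)
        return [lo + i * step for i in range(n)]
    if mode == "auto":
        n = max(n, 2 * order)  # assumption: a too-small n is raised to 2 * order
        n_interior = n - 2 * order
        step = (hi - lo) / (n_interior + 1)
        interior = [lo + (i + 1) * step for i in range(n_interior)]
        # Repeating each end knot `order` times gives endpoint interpolation.
        return [lo] * order + interior + [hi] * order
    raise ValueError('mode must be "uniform" or "auto"')


order, n_splines = 3, 2
row = make_knots(7, 0.0, 2.0, "auto", order)
knots = [list(row) for _ in range(n_splines)]  # shape [s, n] = [2, 7]
n_control = len(row) - order                   # knots = order + control points

print(row)        # [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]
print(n_control)  # 4
print(make_knots(5, 0.0, 2.0, "uniform", order))  # [0.0, 0.5, 1.0, 1.5, 2.0]
```

The "auto" row reproduces the clamped knot vector used in the class example, and the control-point count follows from the rule that the number of knots equals the order plus the number of control points.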

Raises:
  • AssertionError – If order is not a positive integer.

  • ValueError – If knots is neither a torch.Tensor nor a dictionary, when provided.

  • ValueError – If aggregate_output is not None, "mean", or "sum".

  • ValueError – If control_points is not a torch.Tensor, when provided.

  • ValueError – If both knots and control_points are None.

  • ValueError – If knots is not two-dimensional, after processing.

  • ValueError – If control_points, after expansion when two-dimensional, is not three-dimensional.

  • ValueError – If, for any univariate spline, the number of knots is not equal to the sum of order and the number of control_points.

  • UserWarning – If, for any univariate spline, the number of control_points is lower than the order, resulting in a degenerate spline.

  • ValueError – If the numbers of univariate splines in knots and control_points do not match.

basis(x, collection=False)

Evaluate the B-spline basis functions for each univariate spline.

This method applies the Cox-de Boor recursion in vectorized form across all univariate splines of the vector spline.

Parameters:
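  • x (torch.Tensor) – The points at which the basis functions are evaluated, given as a two-dimensional tensor with one input feature per univariate spline.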

  • collection (bool) – If True, returns a list of basis functions for all orders up to the spline order. Default is False.

Raises:
  • ValueError – If collection is not a boolean.

  • ValueError – If x is not two-dimensional.

  • ValueError – If the number of input features does not match the number of univariate splines.

Returns:

The basis functions evaluated at x.

Return type:

torch.Tensor
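
The idea behind `collection=True` can be sketched with a bottom-up Cox-de Boor pass that keeps every intermediate order. The snippet below is a plain-Python illustration for a single univariate spline and a single scalar input. The helper `basis_table` and its list-of-lists return format are assumptions and do not reflect the library's actual tensor layout.

```python
def basis_table(x, knots, order):
    """Return [B_1, ..., B_order]; B_k lists all order-k basis values at x."""
    m = len(knots)
    # Order 1: indicators of the half-open spans [t_i, t_{i+1}).
    current = [1.0 if knots[i] <= x < knots[i + 1] else 0.0 for i in range(m - 1)]
    table = [current]
    for k in range(2, order + 1):
        nxt = []
        for i in range(m - k):
            left_den = knots[i + k - 1] - knots[i]
            right_den = knots[i + k] - knots[i + 1]
            # Zero-width spans (repeated knots) contribute nothing.
            left = (x - knots[i]) / left_den * current[i] if left_den > 0 else 0.0
            right = (knots[i + k] - x) / right_den * current[i + 1] if right_den > 0 else 0.0
            nxt.append(left + right)
        current = nxt
        table.append(current)
    return table


knots = [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]
table = basis_table(0.5, knots, order=3)

print([len(row) for row in table])  # [6, 5, 4]
print(table[-1])                    # [0.25, 0.625, 0.125, 0.0]
print(sum(table[-1]))               # 1.0
```

Each pass removes one basis function, so an order-k row has `len(knots) - k` entries. The top row sums to one inside the knot range, which is the partition-of-unity property of B-splines.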

forward(x)

Forward pass for the VectorizedSpline model. Each univariate spline is evaluated independently on its corresponding input feature.

The input is expected to have shape [batch, s], where s is the number of univariate splines. By default, the output has shape [batch, s, o], where o is the output dimension of each univariate spline. The output is instead aggregated across the last dimension, giving shape [batch, s], in two cases: when both s and o are 1, and when aggregate_output is set to "mean" or "sum".

Parameters:

x (torch.Tensor | LabelTensor) – The input tensor.

Returns:

The output tensor.

Return type:

torch.Tensor
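
The effect of `aggregate_output` on shapes can be sketched with nested lists. The `aggregate` helper below is hypothetical and only mirrors the behaviour described above: a reduction over the last axis that turns [batch, s, o] into [batch, s].

```python
def aggregate(output, mode=None):
    """Reduce a nested list of shape [batch, s, o] over its last axis."""
    if mode is None:
        return output
    if mode == "sum":
        return [[sum(vec) for vec in sample] for sample in output]
    if mode == "mean":
        return [[sum(vec) / len(vec) for vec in sample] for sample in output]
    raise ValueError('mode must be None, "mean", or "sum"')


# One sample, s = 2 univariate splines, o = 3 outputs each: shape [1, 2, 3].
out = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]

print(aggregate(out, "sum"))   # [[6.0, 15.0]]
print(aggregate(out, "mean"))  # [[2.0, 5.0]]
```

Both reductions act within a single univariate spline, so outputs belonging to different input features are never mixed.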

derivative(x, degree)

Compute the degree-th derivative of each univariate spline at the given input points.

By default, the output has shape [batch, s, o], where o is the output dimension of each univariate spline. The output is instead aggregated across the last dimension, giving shape [batch, s], in two cases: when both s and o are 1, and when aggregate_output is set to "mean" or "sum".

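Parameters:
  • x – The input points at which the derivative is evaluated.

  • degree – The order of the derivative to compute.
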
Returns:

The derivative tensor.

Return type:

torch.Tensor
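
One classical way to differentiate a B-spline is to difference its control points: the derivative of an order-k spline is an order-(k-1) spline on the knot vector with its first and last knots removed. The plain-Python sketch below applies that rule to a single scalar-output univariate spline. This page does not say how `derivative()` is implemented, so the approach and the helper names `basis`, `spline_value`, and `spline_derivative` are assumptions for illustration.

```python
def basis(i, k, t, x):
    """Order-k Cox-de Boor basis function B_{i,k}(x) on knots t (0-based i)."""
    if k == 1:
        return 1.0 if t[i] <= x < t[i + 1] else 0.0
    a, b = t[i + k - 1] - t[i], t[i + k] - t[i + 1]
    left = (x - t[i]) / a * basis(i, k - 1, t, x) if a > 0 else 0.0
    right = (t[i + k] - x) / b * basis(i + 1, k - 1, t, x) if b > 0 else 0.0
    return left + right


def spline_value(x, k, t, c):
    """S(x) = sum_i B_{i,k}(x) * c_i."""
    return sum(basis(i, k, t, x) * ci for i, ci in enumerate(c))


def spline_derivative(x, order, knots, ctrl, degree=1):
    """degree-th derivative at x, obtained by repeated control-point differencing."""
    k, t, c = order, list(knots), list(ctrl)
    for _ in range(degree):
        if k == 1:
            return 0.0  # piecewise constant: derivative vanishes between knots
        p = k - 1  # polynomial degree of the current spline
        # New control points: p * (c_{i+1} - c_i) / (t_{i+k} - t_{i+1}).
        c = [
            p * (c[i + 1] - c[i]) / (t[i + k] - t[i + 1]) if t[i + k] > t[i + 1] else 0.0
            for i in range(len(c) - 1)
        ]
        t = t[1:-1]  # drop the first and last knots
        k -= 1
    return spline_value(x, k, t, c)


knots = [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0]
ctrl = [0.5, 1.0, 3.0, 2.0]

# Compare against a central finite difference at an interior point.
x, h = 0.5, 1e-5
fd = (spline_value(x + h, 3, knots, ctrl) - spline_value(x - h, 3, knots, ctrl)) / (2 * h)
print(spline_derivative(x, 3, knots, ctrl, degree=1), fd)
print(spline_derivative(x, 3, knots, ctrl, degree=2))
```

On the first knot span this spline equals 0.5 + x + 0.5 x², so the first and second derivatives at x = 0.5 are 1.5 and 1.0, which is what the sketch returns up to rounding.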

property control_points

The control points of the spline.

Returns:

The control points.

Return type:

torch.Tensor

property knots

The knots of the spline.

Returns:

The knots.

Return type:

torch.Tensor