"""Vectorized univariate B-spline model with per-spline knots."""
import warnings
import torch
from pina._src.core.utils import check_consistency, check_positive_integer
# [docs] (Sphinx source-link marker left over from HTML extraction)
class VectorizedSpline(torch.nn.Module):
r"""
The vectorized B-spline model class.
A :class:`VectorizedSpline` represents a vector spline, i.e., a collection
of independent univariate B-splines evaluated in parallel. Each univariate
spline has its own knot vector and its own control points, and acts on one
input feature.
Given ``s`` univariate splines, the vector spline maps an input
:math:`x = (x^{(1)}, \dots, x^{(s)}) \in \mathbb{R}^s` to an output obtained
by evaluating each univariate spline on its corresponding scalar input
:math:`x^{(j)}`.
For the :math:`j`-th univariate spline of order :math:`k`, the output is
defined as
.. math::
S^{(j)}(x^{(j)}) = \sum_{i=1}^{n_j} B_{i,k}^{(j)}(x^{(j)}) C_i^{(j)},
where:
- :math:`C^{(j)}` are the control points of the :math:`j`-th univariate
spline. In the scalar-output case, :math:`C^{(j)} \in \mathbb{R}^{n_j}`.
More generally, each univariate spline may have output dimension
:math:`o`, so :math:`C^{(j)} \in \mathbb{R}^{o \times n_j}`.
- :math:`B_{i,k}^{(j)}(x)` are the B-spline basis functions of order
:math:`k`, i.e., piecewise polynomials of degree :math:`k-1`, associated
with the knot vector of the :math:`j`-th univariate spline.
- :math:`X^{(j)} = \{x_1^{(j)}, x_2^{(j)}, \dots, x_{m_j}^{(j)}\}` is the
non-decreasing knot vector of the :math:`j`-th univariate spline.
If the first and last knots of a given univariate spline are repeated
:math:`k` times, then that univariate spline interpolates its first and last
control points.
The full vector spline evaluates all univariate splines in parallel. If each
univariate spline has output dimension :math:`o`, then before optional
aggregation the output has shape ``[batch, s, o]``.
.. note::
Each univariate spline is forced to be zero outside the interval defined
by the first and last knots of its own knot vector.
.. note::
This class does not represent a single multivariate spline
:math:`\mathbb{R}^s \to \mathbb{R}^o` with a genuinely multivariate
basis. Instead, it represents a vector of splines built from ``s``
independent univariate splines, one for each input feature.
.. note::
When using the :meth:`derivative` method of this class, derivatives are
computed directly in vectorized form and returned with the correct
shape. In contrast, when relying on ``autograd``, derivatives must be
computed separately for each output dimension of each univariate spline
and then combined, since autograd does not natively handle this
vectorized structure.
:Example:
>>> from pina.model import VectorizedSpline
>>> import torch
>>> knt1 = torch.tensor([
... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
... [0.0, 0.0, 0.0, 0.5, 1.0, 1.0, 1.0],
... ])
>>> spline1 = VectorizedSpline(order=3, knots=knt1, control_points=None)
>>> knt2 = {"n": 7, "min": 0.0, "max": 2.0, "mode": "auto", "n_splines": 2}
>>> spline2 = VectorizedSpline(order=3, knots=knt2, control_points=None)
>>> knt3 = torch.tensor([
... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
... [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0],
... ])
>>> ctrl3 = torch.tensor([
... [0.0, 1.0, 3.0, 2.0],
... [1.0, 0.0, 2.0, 1.0],
... ])
>>> spline3 = VectorizedSpline(order=3, knots=knt3, control_points=ctrl3)
"""
def __init__(
    self,
    order=4,
    knots=None,
    control_points=None,
    aggregate_output=None,
):
    """
    Build a :class:`VectorizedSpline` from its order, knots and control
    points.

    At least one of ``knots`` and ``control_points`` must be supplied. When
    only ``control_points`` is given, an interpolating (``"auto"``) knot
    vector over :math:`[0, 1]` with ``order + c`` knots is generated for
    each univariate spline. When only ``knots`` is given, the control
    points are created by the ``control_points`` setter as zero-valued
    learnable parameters.

    :param int order: The order of each univariate spline; the basis
        functions are polynomials of degree ``order - 1``. Default is 4.
    :param knots: Either a tensor of shape ``[s, n]`` (``s`` univariate
        splines, ``n`` knots each) or a dictionary with the keys ``"n"``,
        ``"min"``, ``"max"``, ``"mode"`` and ``"n_splines"``. Mode
        ``"uniform"`` spaces ``n`` knots evenly over :math:`[min, max]`;
        mode ``"auto"`` repeats the first and last knot ``order`` times so
        that each spline interpolates its first and last control points
        (at least ``2 * order`` knots are produced). Default is None.
    :type knots: torch.Tensor | dict
    :param torch.Tensor control_points: Tensor of shape ``[s, o, c]`` or
        ``[s, c]``; the latter is expanded to ``[s, 1, c]``. If None, zero
        initialized learnable control points are created. Default is None.
    :param str aggregate_output: None to keep the per-spline output of
        shape ``[batch, s, o]``; ``"mean"`` or ``"sum"`` to reduce the last
        dimension and obtain shape ``[batch, s]``. Default is None.
    :raises ValueError: If both ``knots`` and ``control_points`` are None.
    :raises ValueError: If ``aggregate_output`` is not None, "mean", or
        "sum".
    :raises ValueError: If ``knots`` is not two-dimensional after
        processing.
    :raises ValueError: If ``control_points`` is not three-dimensional
        after the optional expansion.
    :raises ValueError: If the number of knots per spline differs from
        ``order`` plus the number of control points per spline.
    :raises ValueError: If ``knots`` and ``control_points`` describe a
        different number of univariate splines.

    .. note::
        The types of ``order``, ``knots`` and ``control_points`` are
        validated through ``check_positive_integer`` and
        ``check_consistency``; the exception raised on failure is the one
        defined by those helpers. A ``UserWarning`` is emitted (not raised)
        when there are fewer control points than ``order``.
    """
    super().__init__()

    # Argument validation delegated to the shared helpers
    check_positive_integer(value=order, strict=True)
    check_consistency(knots, (type(None), torch.Tensor, dict))
    check_consistency(control_points, (type(None), torch.Tensor))

    # Something must define the size of the spline
    if control_points is None and knots is None:
        raise ValueError("knots and control_points cannot both be None.")

    # Only these aggregation modes are understood by forward/derivative
    allowed_aggregations = (None, "mean", "sum")
    if aggregate_output not in allowed_aggregations:
        raise ValueError(
            f"aggregate_output must be None, 'mean', or 'sum'."
            f" Got {aggregate_output}."
        )

    # Missing knots: control_points is necessarily available at this point,
    # so derive an interpolating knot specification from its shape
    if knots is None:
        knots = dict(
            n=control_points.shape[-1] + order,
            min=0,
            max=1,
            n_splines=control_points.shape[0],
            mode="auto",
        )

    # The property setters normalise knots and control points; the order
    # must be stored first because both setters read it
    self.order = order
    self.knots = knots
    self.control_points = control_points
    self.aggregate_output = aggregate_output

    # Control points must be [s, o, c] once the setter has run
    if self.control_points.ndim != 3:
        raise ValueError("control_points must be three-dimensional.")

    n_knots = self.knots.shape[-1]
    n_ctrl = self.control_points.shape[-1]

    # Enforce the relation #knots = order + #control_points
    if n_knots - n_ctrl != self.order:
        raise ValueError(
            f" The number of knots per spline must be equal to order + the"
            f" number of control points. Got {self.knots.shape[-1]} knots"
            f" per spline, {self.control_points.shape[-1]} control points,"
            f" and {self.order} order."
        )

    # Fewer control points than the order is allowed but worth flagging
    if self.order > n_ctrl:
        warnings.warn(
            "The number of control points per spline is smaller than the"
            " spline order. This creates a degenerate spline with limited"
            " flexibility.",
            UserWarning,
        )

    # Knots and control points must describe the same number of splines
    if self.control_points.shape[0] != self.knots.shape[0]:
        raise ValueError(
            f"The number of splines must be the same for knots and"
            f" control points. Got {self.knots.shape[0]} splines for knots"
            f" and {self.control_points.shape[0]} splines for control"
            f" points."
        )

    # Cache the index of the rightmost non-degenerate interval
    boundary_idx = self._compute_boundary_interval()
    self.register_buffer("_boundary_interval_idx", boundary_idx)

    # Cache the factors used by the derivative recurrence
    self._compute_derivative_denominators()
def _compute_boundary_interval(self):
"""
Precompute the index of the rightmost non-degenerate interval to improve
performance, eliminating the need to perform a search loop in the basis
function on each call.
:return: The index of the rightmost non-degenerate interval for each
univariate spline.
:rtype: torch.Tensor
"""
# Compute the differences between consecutive knots for each spline
diffs = self._knots[:, 1:] - self._knots[:, :-1]
valid = diffs > 0
# Initialize idx tensor to store the last valid interval for each spline
idx = torch.zeros(
self._knots.shape[0], dtype=torch.long, device=self._knots.device
)
# For each spline, find the last idx where interval is non-degenerate
for s in range(self._knots.shape[0]):
valid_s = torch.nonzero(valid[s], as_tuple=False)
idx[s] = valid_s[-1, 0] if valid_s.numel() > 0 else 0
return idx
def _compute_derivative_denominators(self):
"""
Precompute the denominators used in the derivatives for all orders up to
the spline order to avoid redundant calculations.
"""
# Precompute for order 2 to k
for i in range(2, self.order + 1):
# Denominators for the derivative recurrence relations
left_den = self.knots[:, i - 1 : -1] - self.knots[:, :-i]
right_den = self.knots[:, i:] - self.knots[:, 1 : -i + 1]
# If consecutive knots are equal, set left and right factors to zero
left_fac = torch.where(
torch.abs(left_den) > 1e-10,
(i - 1) / left_den,
torch.zeros_like(left_den),
)
right_fac = torch.where(
torch.abs(right_den) > 1e-10,
(i - 1) / right_den,
torch.zeros_like(right_den),
)
# Register buffers
self.register_buffer(f"_left_factor_order_{i}", left_fac)
self.register_buffer(f"_right_factor_order_{i}", right_fac)
# [docs] (Sphinx source-link marker left over from HTML extraction)
def basis(self, x, collection=False):
    """
    Evaluate the B-spline basis functions of every univariate spline.

    The Cox-de Boor recursion is applied iteratively and in vectorized form
    across all univariate splines, starting from the indicator functions of
    the knot intervals (order 1) up to order ``self.order``.

    :param torch.Tensor x: The evaluation points, of shape ``[batch, s]``
        with one column per univariate spline. The input is viewed as a
        plain :class:`torch.Tensor` and cast to the dtype of the knots.
    :param bool collection: If True, the basis functions of all orders are
        returned instead of only those of order ``self.order``.
        Default is False.
    :raises ValueError: If ``x`` is not two-dimensional.
    :raises ValueError: If the number of input features does not match
        the number of univariate splines.
    :return: If ``collection`` is False, the basis of order ``self.order``
        with shape ``[batch, s, n - order]``, where ``n`` is the number of
        knots per spline. If ``collection`` is True, a list whose entry
        ``i`` (for ``i >= 1``) is the basis of order ``i`` and whose entry
        0 is None.
    :rtype: torch.Tensor | list

    .. note::
        ``collection`` is validated through ``check_consistency``; the
        exception raised on failure is the one defined by that helper.
    """
    check_consistency(collection, bool)

    # Plain tensor, same dtype as the knots
    x = x.as_subclass(torch.Tensor).to(dtype=self.knots.dtype)

    # Input must be (batch, s)
    if x.ndim != 2:
        raise ValueError(
            f"The input must have shape (batch, s). Got {x.shape}."
        )

    # One input feature per univariate spline
    n_splines = self.knots.shape[0]
    if x.shape[1] != n_splines:
        raise ValueError(
            f"The number of input features must be the same as the number"
            f" of univariate splines. Got {x.shape[1]} input features,"
            f" but {self.knots.shape[0]} univariate splines."
        )

    # Broadcastable views: points [batch, s, 1], knots [1, s, n]
    pts = x.unsqueeze(-1)
    grid = self.knots.unsqueeze(0)

    # Order 1: indicator of the half-open interval [t_i, t_{i+1})
    inside = (pts >= grid[..., :-1]) & (pts < grid[..., 1:])
    current = inside.to(pts.dtype)

    # End points of the rightmost non-degenerate interval of each spline
    spline_ids = torch.arange(n_splines, device=pts.device)
    last_idx = self._boundary_interval_idx
    left_edge = self.knots[spline_ids, last_idx]
    right_edge = self.knots[spline_ids, last_idx + 1]

    # Points sitting on the right end of that interval are excluded by the
    # half-open indicator above, so they are re-included explicitly
    on_right_edge = (x >= left_edge.unsqueeze(0)) & torch.isclose(
        x, right_edge.unsqueeze(0), rtol=1e-8, atol=1e-10
    )
    if on_right_edge.any():
        rows, cols = on_right_edge.nonzero(as_tuple=True)
        current[rows, cols, last_idx[cols]] = 1.0

    # Optional storage of every order (index 0 is an unused placeholder)
    levels = [None, current] if collection else None

    # Cox-de Boor recursion, raising the order by one at each step
    for i in range(1, self.order):
        span_left = grid[..., i:-1] - grid[..., : -(i + 1)]
        span_right = grid[..., i + 1 :] - grid[..., 1:-i]

        # Replace (near-)zero spans by one to avoid dividing by zero
        span_left = torch.where(
            span_left.abs() < 1e-8, torch.ones_like(span_left), span_left
        )
        span_right = torch.where(
            span_right.abs() < 1e-8, torch.ones_like(span_right), span_right
        )

        # The two contributions of the recursion
        rising = ((pts - grid[..., : -(i + 1)]) / span_left) * current[
            ..., :-1
        ]
        falling = ((grid[..., i + 1 :] - pts) / span_right) * current[
            ..., 1:
        ]
        current = rising + falling

        if collection:
            levels.append(current)

    return levels if collection else current
# [docs] (Sphinx source-link marker left over from HTML extraction)
def forward(self, x):
"""
Forward pass for the :class:`VectorizedSpline` model. Each univariate
spline is evaluated independently on its corresponding input feature.
The input is expected to have shape ``[batch, s]``, where ``s`` is the
number of univariate splines. The output has shape ``[batch, s, o]``,
where ``o`` is the output dimension of each univariate spline, unless an
aggregation method is specified. If both ``s`` and ``o`` are 1, the
output is aggregated across the last dimension, resulting in an output
of shape ``[batch, s]``. If ``aggregate_output`` is set to ``"mean"`` or
``"sum"``, the output is aggregated across the last dimension, resulting
in an output of shape ``[batch, s]``.
:param x: The input tensor.
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor
"""
# Compute the basis functions at x
basis = self.basis(x)
# Compute the output for each spline
out = torch.einsum("bsc,soc->bso", basis, self.control_points)
# Aggregate output if needed
if self.aggregate_output == "mean":
out = out.mean(dim=-1)
elif self.aggregate_output == "sum":
out = out.sum(dim=-1)
elif out.shape[1] == 1 and out.shape[2] == 1:
out = out.squeeze(-1)
return out
# [docs] (Sphinx source-link marker left over from HTML extraction)
def derivative(self, x, degree):
    """
    Evaluate the ``degree``-th derivative of every univariate spline at the
    given points.

    The derivatives of the basis functions are contracted with the control
    points, giving a tensor of shape ``[batch, s, o]``. The last dimension
    is then reduced when ``aggregate_output`` is ``"mean"`` or ``"sum"``
    (output shape ``[batch, s]``). Without aggregation, the last dimension
    is squeezed only when both ``s`` and ``o`` are 1 (output shape
    ``[batch, 1]``).

    :param x: The input tensor, of shape ``[batch, s]``.
    :type x: torch.Tensor | LabelTensor
    :param int degree: The derivative degree to compute. It is validated
        through ``check_positive_integer`` with ``strict=False``.
    :return: The derivative tensor.
    :rtype: torch.Tensor
    """
    check_positive_integer(degree, strict=False)

    # Derivative of the basis functions, evaluated on a plain tensor
    plain_x = x.as_subclass(torch.Tensor)
    basis_der = self._basis_derivative(plain_x, degree=degree)

    # Contract [b, s, c] with the control points [s, o, c]
    values = torch.einsum("bsc,soc->bso", basis_der, self.control_points)

    mode = self.aggregate_output
    if mode == "mean":
        return values.mean(dim=-1)
    if mode == "sum":
        return values.sum(dim=-1)

    # Single spline with scalar output: drop the trailing dimension
    if tuple(values.shape[1:]) == (1, 1):
        return values.squeeze(-1)
    return values
def _basis_derivative(self, x, degree):
"""
Compute the ``degree``-th derivative of the vectorized spline basis
functions at the given input points using an iterative approach.
:param torch.Tensor x: The points to be evaluated.
:param int degree: The derivative degree to compute.
:return: The derivative of the basis functions of order ``self.order``.
:rtype: torch.Tensor
"""
# Compute the whole basis collection
basis = self.basis(x, collection=True)
# Derivatives initialization (dummy at index 0 for convenience)
derivatives = [None] + [basis[o] for o in range(1, self.order + 1)]
# Iterate over derivative degrees
for _ in range(1, degree + 1):
# Current degree derivatives (with dummy at index 0 for convenience)
current_der = [None] * (self.order + 1)
current_der[1] = torch.zeros_like(derivatives[1])
# Iterate over basis orders
for o in range(2, self.order + 1):
# Retrieve precomputed factors
left_fac = getattr(self, f"_left_factor_order_{o}")
right_fac = getattr(self, f"_right_factor_order_{o}")
# derivatives[o - 1] has shape [b, s, m]
# Slice previous derivatives to align
left_part = derivatives[o - 1][..., :-1]
right_part = derivatives[o - 1][..., 1:]
# Broadcast factors over batch dims
left_fac = left_fac.unsqueeze(0)
right_fac = right_fac.unsqueeze(0)
# Compute current derivatives
current_der[o] = left_fac * left_part - right_fac * right_part
# Update derivatives for next degree
derivatives = current_der
return derivatives[self.order]
@property
def control_points(self):
"""
The control points of the spline.
:return: The control points.
:rtype: torch.Tensor
"""
return self._control_points
@control_points.setter
def control_points(self, control_points):
"""
Set the control points of the spline.
:param torch.Tensor control_points: The control points tensor. The
tensor must be either of shape ``[s, o, c]`` or ``[s, c]``, where
each univariate spline has ``c`` control points and output dimension
``o``. In the latter case, the control points are expanded to shape
``[s, 1, c]``.
:raises ValueError: If there are not enough knots to define the control
points, due to the relation: #knots = order + #control_points.
"""
# If control points are not provided, initialize them
if control_points is None:
# Check that there are enough knots to define control points
if self.knots.shape[-1] < self.order + 1:
raise ValueError(
f"Not enough knots to define control points. Got"
f" {self.knots.shape[-1]} knots for each univariate spline,"
f" but need at least {self.order + 1}."
)
# Initialize control points to zero
control_points = torch.zeros(
self.knots.shape[0], 1, self.knots.shape[-1] - self.order
)
# If a the control points are 2D, add an output dimension of size 1
if control_points.ndim == 2:
control_points = control_points.unsqueeze(1)
# Set control points
self._control_points = torch.nn.Parameter(
control_points, requires_grad=True
)
@property
def knots(self):
"""
The knots of the spline.
:return: The knots.
:rtype: torch.Tensor
"""
return self._knots
@knots.setter
def knots(self, value):
"""
Set the knots of the spline.
:param value: The knots of the spline. If a tensor is provided, it must
have shape ``[s, n]``, where ``s`` is the number of univariate
splines and ``n`` is the number of knots per univariate spline. If a
dictionary is provided, it must contain the keys ``"n"``, ``"min"``,
``"max"``, ``"mode"``, and ``"n_splines"``. Here, ``"n"`` specifies
the number of knots for each univariate spline, ``"min"`` and
``"max"`` define the interval, ``"mode"`` selects the sampling
strategy, and ``"n_splines"`` specifies the number of univariate
splines. The supported modes are ``"uniform"``, where the knots are
evenly spaced over :math:`[min, max]`, and ``"auto"``, where knots
are constructed to ensure that each univariate spline interpolates
the first and last control points. In this case, the number of knots
is adjusted if :math:`n < 2 * order`. If None is given, knots are
initialized automatically over :math:`[0, 1]` ensuring interpolation
of the first and last control points.
:type value: torch.Tensor | dict
:raises ValueError: If a dictionary is provided but does not contain
the required keys.
:raises ValueError: If the mode specified in the dictionary is invalid.
:raises ValueError: If knots is not two-dimensional after processing.
"""
# If a dictionary is provided, initialize knots accordingly
if isinstance(value, dict):
# Check that required keys are present
required_keys = {"n", "min", "max", "mode", "n_splines"}
if not required_keys.issubset(value.keys()):
raise ValueError(
f"When providing knots as a dictionary, the following "
f"keys must be present: {required_keys}. Got "
f"{value.keys()}."
)
# Save number of splines for later use
n_splines = value["n_splines"]
# Uniform sampling of knots
if value["mode"] == "uniform":
value = torch.linspace(value["min"], value["max"], value["n"])
# Automatic sampling of interpolating knots
elif value["mode"] == "auto":
# Repeat the first and last knots 'order' times
initial_knots = torch.ones(self.order) * value["min"]
final_knots = torch.ones(self.order) * value["max"]
# Number of internal knots
n_internal = value["n"] - 2 * self.order
# If no internal knots are needed, just concatenate boundaries
if n_internal <= 0:
value = torch.cat((initial_knots, final_knots))
# Else, sample internal knots uniformly and exclude boundaries
# Recover the correct number of internal knots when slicing by
# adding 2 to n_internal
else:
internal_knots = torch.linspace(
value["min"], value["max"], n_internal + 2
)[1:-1]
value = torch.cat(
(initial_knots, internal_knots, final_knots)
)
# Raise error if mode is invalid
else:
raise ValueError(
f"Invalid mode for knots initialization. Got "
f"{value['mode']}, but expected 'uniform' or 'auto'."
)
# Repeat the knot vector for each spline
value = value.unsqueeze(0).repeat(n_splines, 1)
# Set knots
self.register_buffer("_knots", value.sort(dim=-1).values)
# Check dimensionality of knots
if self.knots.ndim != 2:
raise ValueError("knots must be two-dimensional.")
# Recompute boundary interval when knots change
if hasattr(self, "_boundary_interval_idx"):
self.register_buffer(
"_boundary_interval_idx", self._compute_boundary_interval()
)
# Recompute derivative denominators when knots change
self._compute_derivative_denominators()