# Source code for pina.model.layers.rbf_layer

"""Module for Radial Basis Function Interpolation layer."""

import math
import warnings
from itertools import combinations_with_replacement
import torch
from ...utils import check_consistency


def linear(r):
    """
    Linear radial basis function, ``phi(r) = -r``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    return torch.neg(r)


def thin_plate_spline(r, eps=1e-7):
    """
    Thin plate spline radial basis function, ``phi(r) = r**2 * log(r)``.

    Distances are clamped from below to ``eps`` so that the logarithm is
    never evaluated at zero.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :param float eps: Lower bound applied to ``r`` before evaluation.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    safe_r = r.clamp(min=eps)
    return safe_r**2 * safe_r.log()


def cubic(r):
    """
    Cubic radial basis function, ``phi(r) = r**3``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    return pow(r, 3)


def quintic(r):
    """
    Quintic radial basis function, ``phi(r) = -r**5``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    return -pow(r, 5)


def multiquadric(r):
    """
    Multiquadric radial basis function, ``phi(r) = -sqrt(r**2 + 1)``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    radicand = pow(r, 2) + 1
    return -torch.sqrt(radicand)


def inverse_multiquadric(r):
    """
    Inverse multiquadric radial basis function,
    ``phi(r) = 1 / sqrt(r**2 + 1)``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    denominator = torch.sqrt(pow(r, 2) + 1)
    return 1 / denominator


def inverse_quadratic(r):
    """
    Inverse quadratic radial basis function, ``phi(r) = 1 / (r**2 + 1)``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    denominator = pow(r, 2) + 1
    return 1 / denominator


def gaussian(r):
    """
    Gaussian radial basis function, ``phi(r) = exp(-r**2)``.

    :param torch.Tensor r: Tensor of (scaled) distances.
    :return: Kernel values, with the same shape as ``r``.
    :rtype: torch.Tensor
    """
    exponent = -pow(r, 2)
    return torch.exp(exponent)


# Registry mapping each supported kernel name to its radial function.
radial_functions = dict(
    linear=linear,
    thin_plate_spline=thin_plate_spline,
    cubic=cubic,
    quintic=quintic,
    multiquadric=multiquadric,
    inverse_multiquadric=inverse_multiquadric,
    inverse_quadratic=inverse_quadratic,
    gaussian=gaussian,
)

# Kernels for which ``RBFBlock`` defaults the shape parameter ``epsilon``
# to 1.0 when it is not given.
scale_invariant = set(("linear", "thin_plate_spline", "cubic", "quintic"))

# Minimum degree of the appended polynomial per kernel; kernels absent from
# this mapping are looked up with a default of -1 by ``RBFBlock``.
min_degree_funcs = dict(
    multiquadric=0,
    linear=0,
    thin_plate_spline=1,
    cubic=1,
    quintic=2,
)


class RBFBlock(torch.nn.Module):
    """
    Radial Basis Function (RBF) interpolation layer. It needs to be fitted
    to the data with the method :meth:`fit` before it can be used to
    interpolate new points. The layer is not trainable.

    .. note::
        It reproduces the implementation of
        ``scipy.interpolate.RBFInterpolator`` and it is inspired from the
        implementation in `torchrbf.
        <https://github.com/ArmanMaesumi/torchrbf>`_
    """

    def __init__(
        self,
        neighbors=None,
        smoothing=0.0,
        kernel="thin_plate_spline",
        epsilon=None,
        degree=None,
    ):
        """
        :param int neighbors: Number of neighbors to use for the
            interpolation. If ``None``, use all data points. Any other value
            currently makes :meth:`fit` raise ``NotImplementedError``.
        :param float smoothing: Smoothing parameter for the interpolation.
            If 0.0, the interpolation is exact and no smoothing is applied.
            May also be a tensor with one value per data point.
        :param str kernel: Radial basis function to use (case-insensitive).
            Must be one of ``linear``, ``thin_plate_spline``, ``cubic``,
            ``quintic``, ``multiquadric``, ``inverse_multiquadric``,
            ``inverse_quadratic``, or ``gaussian``.
        :param float epsilon: Shape parameter that scales the input to the
            RBF. It defaults to 1 for kernels in the ``scale_invariant`` set
            and must be specified for the other kernels.
        :param int degree: Degree of the added polynomial. For some kernels
            there exists a minimum degree of the polynomial such that the
            RBF is well-posed; those minimum degrees are listed in the
            ``min_degree_funcs`` dictionary. If ``degree`` is less than the
            minimum degree, a warning is raised and the degree is set to
            the minimum value.
        """
        super().__init__()
        check_consistency(neighbors, (int, type(None)))
        check_consistency(smoothing, (int, float, torch.Tensor))
        check_consistency(kernel, str)
        check_consistency(epsilon, (float, type(None)))
        check_consistency(degree, (int, type(None)))

        self.neighbors = neighbors
        self.smoothing = smoothing
        # ``kernel`` must be set before ``epsilon`` and ``degree``: their
        # setters read ``self.kernel`` to choose defaults.
        self.kernel = kernel
        self.epsilon = epsilon
        self.degree = degree
        self.powers = None
        # data points and values, set by ``fit``
        self.y = None
        self.d = None
        # attributes of the fitted model, set by ``fit``
        self._shift = None
        self._scale = None
        self._coeffs = None

    @property
    def smoothing(self):
        """
        Smoothing parameter for the interpolation, as given by the user
        (a scalar or a per-point tensor).

        :rtype: float
        """
        return self._smoothing

    @smoothing.setter
    def smoothing(self, value):
        self._smoothing = value

    @property
    def kernel(self):
        """
        Radial basis function to use (stored lower-case).

        :rtype: str
        """
        return self._kernel

    @kernel.setter
    def kernel(self, value):
        # Normalise case *before* the lookup so that e.g. "Gaussian" is
        # accepted; the registry keys are all lower-case.
        key = value.lower()
        if key not in radial_functions:
            raise ValueError(f"Unknown kernel: {value}")
        self._kernel = key

    @property
    def epsilon(self):
        """
        Shape parameter that scales the input to the RBF.

        :rtype: float
        """
        return self._epsilon

    @epsilon.setter
    def epsilon(self, value):
        if value is None:
            if self.kernel in scale_invariant:
                value = 1.0
            else:
                raise ValueError("Must specify `epsilon` for this kernel.")
        else:
            value = float(value)
        self._epsilon = value

    @property
    def degree(self):
        """
        Degree of the added polynomial (-1 means no polynomial term).

        :rtype: int
        """
        return self._degree

    @degree.setter
    def degree(self, value):
        min_degree = min_degree_funcs.get(self.kernel, -1)
        if value is None:
            value = max(min_degree, 0)
        else:
            value = int(value)
            if value < -1:
                raise ValueError("`degree` must be at least -1.")
            if value < min_degree:
                warnings.warn(
                    "`degree` is too small for this kernel. Setting to "
                    f"{min_degree}.",
                    UserWarning,
                )
                # Honour the warning (and the documented contract): the
                # degree is raised to the kernel's minimum.
                value = min_degree
        self._degree = value

    @staticmethod
    def _float_dtype(tensor):
        """
        Floating dtype used for the linear system built from ``tensor``:
        the tensor's own dtype if it is floating point, else torch's default.

        :param torch.Tensor tensor: Reference tensor.
        :rtype: torch.dtype
        """
        if tensor.is_floating_point():
            return tensor.dtype
        return torch.get_default_dtype()

    def _check_data(self, y, d):
        """
        Validate the data passed to :meth:`fit`.

        :param torch.Tensor y: Data points, must be 2-dimensional.
        :param torch.Tensor d: Data values, first dim must match ``y``.
        :raises ValueError: If the shapes are inconsistent.
        """
        if y.ndim != 2:
            raise ValueError("y must be a 2-dimensional tensor.")
        if d.shape[0] != y.shape[0]:
            raise ValueError(
                "The first dim of d must have the same length as "
                "the first dim of y."
            )

    def _smoothing_vector(self, y):
        """
        Expand ``self.smoothing`` into one smoothing value per row of ``y``.

        The user setting is left untouched, so the layer can be re-fitted
        on data with a different number of points.

        :param torch.Tensor y: Data points, shape ``(n, d)``.
        :raises ValueError: If a tensor smoothing does not have length ``n``.
        :return: ``(n,)`` tensor on ``y``'s device.
        :rtype: torch.Tensor
        """
        nobs = y.shape[0]
        smoothing = torch.as_tensor(
            self.smoothing, dtype=RBFBlock._float_dtype(y), device=y.device
        )
        if smoothing.ndim == 0:
            smoothing = smoothing.expand(nobs)
        if smoothing.shape != (nobs,):
            raise ValueError(
                "`smoothing` must be a scalar or a tensor of length "
                f"{nobs}."
            )
        return smoothing

    def fit(self, y, d):
        """
        Fit the RBF interpolator to the data.

        :param torch.Tensor y: (n, d) tensor of data points.
        :param torch.Tensor d: (n, m) tensor of data values.
        :raises NotImplementedError: If ``neighbors`` is not ``None``.
        :raises ValueError: If the data is inconsistent, has too few points
            for the requested degree, or the linear system is singular.
        """
        self._check_data(y, d)
        if self.neighbors is not None:
            raise NotImplementedError("neighbors currently not supported")

        nobs = y.shape[0]
        powers = RBFBlock.monomial_powers(y.shape[1], self.degree).to(
            y.device
        )
        # The polynomial block needs at least as many points as monomials.
        if powers.shape[0] > nobs:
            raise ValueError(
                "The data is not compatible with the requested degree."
            )

        self._shift, self._scale, self._coeffs = RBFBlock.solve(
            y,
            d.reshape((nobs, -1)),
            self._smoothing_vector(y),
            self.kernel,
            self.epsilon,
            powers,
        )
        # Only record the data once the system was solved successfully.
        self.y = y
        self.d = d
        self.powers = powers

    def forward(self, x):
        """
        Returns the interpolated data at the given points `x`.

        :param torch.Tensor x: `(n, d)` tensor of points at which
            to query the interpolator
        :raises RuntimeError: If the layer has not been fitted yet.
        :raises ValueError: If `x` has the wrong shape.
        :rtype: `(n, m)` torch.Tensor of interpolated data.
        """
        if self._coeffs is None:
            raise RuntimeError(
                "The interpolator must be fitted with `fit` before use."
            )
        if x.ndim != 2:
            raise ValueError("`x` must be a 2-dimensional tensor.")

        nx, ndim = x.shape
        if ndim != self.y.shape[1]:
            raise ValueError(
                "Expected the second dim of `x` to have length "
                f"{self.y.shape[1]}."
            )

        kernel_func = radial_functions[self.kernel]

        yeps = self.y * self.epsilon
        xeps = x * self.epsilon
        # Same shift/scale normalisation used for the polynomial in build.
        xhat = (x - self._shift) / self._scale

        kv = RBFBlock.kernel_vector(xeps, yeps, kernel_func)
        p = RBFBlock.polynomial_matrix(xhat, self.powers)
        vec = torch.cat([kv, p], dim=1)
        out = torch.matmul(vec, self._coeffs)
        return out.reshape((nx,) + self.d.shape[1:])

    @staticmethod
    def kernel_vector(x, y, kernel_func):
        """
        Evaluate radial functions with centers `y` for all points in `x`.

        :param torch.Tensor x: `(n, d)` tensor of points.
        :param torch.Tensor y: `(m, d)` tensor of centers.
        :param callable kernel_func: Radial basis function applied to the
            pairwise Euclidean distances.
        :rtype: `(n, m)` torch.Tensor of radial function values.
        """
        return kernel_func(torch.cdist(x, y))

    @staticmethod
    def polynomial_matrix(x, powers):
        """
        Evaluate monomials at `x` with given `powers`.

        :param torch.Tensor x: `(n, d)` tensor of points.
        :param torch.Tensor powers: `(r, d)` tensor of powers for each
            monomial.
        :rtype: `(n, r)` torch.Tensor of monomial values.
        """
        # (n, 1, d) ** (1, r, d) -> (n, r, d), product over the d variables.
        return torch.prod(x.unsqueeze(1) ** powers.unsqueeze(0), dim=-1)

    @staticmethod
    def kernel_matrix(x, kernel_func):
        """
        Returns radial function values for all pairs of points in `x`.

        :param torch.Tensor x: `(n, d)` tensor of points.
        :param callable kernel_func: Radial basis function applied to the
            pairwise Euclidean distances.
        :rtype: `(n, n)` torch.Tensor of radial function values.
        """
        return kernel_func(torch.cdist(x, x))

    @staticmethod
    def monomial_powers(ndim, degree):
        """
        Return the powers for each monomial in a polynomial.

        :param int ndim: Number of variables in the polynomial.
        :param int degree: Degree of the polynomial (-1 gives no monomials).
        :rtype: `(nmonos, ndim)` torch.Tensor where each row contains the
            powers for each variable in a monomial.
        """
        nmonos = math.comb(degree + ndim, ndim)
        out = torch.zeros((nmonos, ndim), dtype=torch.int32)
        count = 0
        for deg in range(degree + 1):
            for mono in combinations_with_replacement(range(ndim), deg):
                # Each occurrence of a variable index raises its power by 1.
                for var in mono:
                    out[count, var] += 1
                count += 1
        return out

    @staticmethod
    def build(y, d, smoothing, kernel, epsilon, powers):
        """
        Build the RBF linear system.

        :param torch.Tensor y: (n, d) tensor of data points.
        :param torch.Tensor d: (n, m) tensor of data values.
        :param torch.Tensor smoothing: (n,) tensor of smoothing parameters
            (a scalar is broadcast to every point).
        :param str kernel: Radial basis function to use.
        :param float epsilon: Shape parameter that scales the input to the
            RBF.
        :param torch.Tensor powers: (r, d) tensor of powers for each
            monomial.
        :rtype: (lhs, rhs, shift, scale) where `lhs` and `rhs` are the
            left-hand side and right-hand side of the linear system, and
            `shift` and `scale` are the shift and scale parameters.
        """
        p = d.shape[0]
        s = d.shape[1]
        r = powers.shape[0]
        kernel_func = radial_functions[kernel]
        # Keep the data's floating precision (float64 data previously got
        # float32 coefficients, which broke ``forward``).
        dtype = RBFBlock._float_dtype(y)
        device = d.device

        # Map each coordinate of y to [-1, 1] for the polynomial block.
        mins = torch.min(y, dim=0).values
        maxs = torch.max(y, dim=0).values
        shift = (maxs + mins) / 2
        scale = (maxs - mins) / 2
        # A constant coordinate would otherwise cause a division by zero.
        scale[scale == 0.0] = 1.0

        yeps = y * epsilon
        yhat = (y - shift) / scale

        smoothing = torch.as_tensor(smoothing, dtype=dtype, device=device)
        if smoothing.ndim == 0:
            smoothing = smoothing.expand(p)

        # Saddle-point system [[K + S, P], [P^T, 0]].
        lhs = torch.empty((p + r, p + r), dtype=dtype, device=device)
        lhs[:p, :p] = RBFBlock.kernel_matrix(yeps, kernel_func)
        lhs[:p, p:] = RBFBlock.polynomial_matrix(yhat, powers)
        lhs[p:, :p] = lhs[:p, p:].T
        lhs[p:, p:] = 0.0
        lhs[:p, :p] += torch.diag(smoothing)

        rhs = torch.empty((r + p, s), dtype=dtype, device=device)
        rhs[:p] = d
        rhs[p:] = 0.0

        return lhs, rhs, shift, scale

    @staticmethod
    def solve(y, d, smoothing, kernel, epsilon, powers):
        """
        Build then solve the RBF linear system.

        :param torch.Tensor y: (n, d) tensor of data points.
        :param torch.Tensor d: (n, m) tensor of data values.
        :param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
        :param str kernel: Radial basis function to use.
        :param float epsilon: Shape parameter that scales the input to the
            RBF.
        :param torch.Tensor powers: (r, d) tensor of powers for each
            monomial.
        :raises ValueError: If the linear system is singular.
        :rtype: (shift, scale, coeffs) where `shift` and `scale` are the
            shift and scale parameters, and `coeffs` are the coefficients
            of the interpolator
        """
        lhs, rhs, shift, scale = RBFBlock.build(
            y, d, smoothing, kernel, epsilon, powers
        )
        try:
            coeffs = torch.linalg.solve(lhs, rhs)
        except RuntimeError as e:
            msg = "Singular matrix."
            nmonos = powers.shape[0]
            if nmonos > 0:
                # Give a more specific diagnosis when the polynomial block
                # is rank-deficient.
                pmat = RBFBlock.polynomial_matrix((y - shift) / scale, powers)
                rank = int(torch.linalg.matrix_rank(pmat))
                if rank < nmonos:
                    msg = (
                        "Singular matrix. The matrix of monomials evaluated at "
                        "the data point coordinates does not have full column "
                        f"rank ({rank}/{nmonos})."
                    )
            raise ValueError(msg) from e

        return shift, scale, coeffs