Source code for pygem.custom_deformation

"""
Module for a custom deformation.
"""
import numpy as np

from pygem import Deformation

class CustomDeformation(Deformation):
    """
    Class to perform a custom deformation to the mesh points.

    :param callable func: the function defining the deformation of the input
        points. This function should take as input: *i*) a 2D array of shape
        (*n_points*, *3*) in which the points are arranged by row, or *ii*) an
        iterable object with 3 components. In this last case, computation of
        deformation is not vectorized and the overall cost may become heavy.

    :Example:

        >>> from pygem import CustomDeformation
        >>> import numpy as np
        >>> def move(x):
        ...     return x + x**2
        >>> deform = CustomDeformation(move)
        >>> original_mesh_points = np.load(
        ...     'tests/test_datasets/meshpoints_sphere_orig.npy')
        >>> new_mesh_points = deform(original_mesh_points)
        >>> # Deformation with non-vectorized function
        >>> def move(x):
        ...     x0, x1, x2 = x
        ...     return [x0**2, x1, x2]
        >>> deform = CustomDeformation(move)
        >>> new_mesh_points = deform(original_mesh_points)
    """

    def __init__(self, func):
        # Store the user-supplied callable; it is invoked by ``__call__``.
        self.__func = func

    def __call__(self, src_pts):
        """
        This method performs the deformation on the input points.

        The stored function is first called once on the whole array. Only if
        that call raises an exception is the function applied to each point
        (row) separately and the results collected into a new array.

        :param numpy.ndarray src_pts: the array of dimensions (*n_points*, *3*)
            containing the points to deform. The points have to be arranged by
            row.
        :return: the deformed points
        :rtype: numpy.ndarray (with shape = (*n_points*, *3*))
        :raises Exception: whatever the stored function raises while being
            applied point by point in the fallback path.
        """
        try:
            # Vectorized attempt: hand the full array to the function.
            return self.__func(src_pts)
        except Exception:
            # Catch ``Exception`` rather than using a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still abort the computation
            # instead of silently restarting it in the slow per-point path.
            return np.array([self.__func(pt) for pt in src_pts])