""" Module for LabelTensor """
from copy import deepcopy
import torch
from torch import Tensor
class LabelTensor(torch.Tensor):
"""Torch tensor with a label for any column."""
@staticmethod
def __new__(cls, x, labels, *args, **kwargs):
return super().__new__(cls, x, *args, **kwargs)
def __init__(self, x, labels):
"""
        Construct a `LabelTensor` by passing a tensor and a list of column
        labels. The labels uniquely identify the columns of the tensor,
        allowing for easier manipulation.

:param torch.Tensor x: The data tensor.
:param labels: The labels of the columns.
:type labels: str | list(str) | tuple(str)
:Example:
>>> from pina import LabelTensor
>>> tensor = LabelTensor(torch.rand((2000, 3)), ['a', 'b', 'c'])
>>> tensor
tensor([[6.7116e-02, 4.8892e-01, 8.9452e-01],
[9.2392e-01, 8.2065e-01, 4.1986e-04],
[8.9266e-01, 5.5446e-01, 6.3500e-01],
...,
[5.8194e-01, 9.4268e-01, 4.1841e-01],
[1.0246e-01, 9.5179e-01, 3.7043e-02],
[9.6150e-01, 8.0656e-01, 8.3824e-01]])
>>> tensor.extract('a')
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor['a']
tensor([[0.0671],
[0.9239],
[0.8927],
...,
[0.5819],
[0.1025],
[0.9615]])
>>> tensor.extract(['a', 'b'])
tensor([[0.0671, 0.4889],
[0.9239, 0.8207],
[0.8927, 0.5545],
...,
[0.5819, 0.9427],
[0.1025, 0.9518],
[0.9615, 0.8066]])
>>> tensor.extract(['b', 'a'])
tensor([[0.4889, 0.0671],
[0.8207, 0.9239],
[0.5545, 0.8927],
...,
[0.9427, 0.5819],
[0.9518, 0.1025],
[0.8066, 0.9615]])
"""
if x.ndim == 1:
x = x.reshape(-1, 1)
if isinstance(labels, str):
labels = [labels]
if len(labels) != x.shape[-1]:
raise ValueError(
"the tensor has not the same number of columns of "
"the passed labels."
)
self._labels = labels
def __deepcopy__(self, __):
"""
        Implement deepcopy for a label tensor. It stores the current
        labels and uses the :meth:`~torch._tensor.Tensor.__deepcopy__`
        method to create a new :class:`pina.label_tensor.LabelTensor`.

        :param __: The memo dictionary passed by :func:`copy.deepcopy`,
            unused here.
:return: The deep copy of the :class:`pina.label_tensor.LabelTensor`.
:rtype: LabelTensor
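
        :Example:
            A minimal sketch; the deep copy keeps the labels of the
            original tensor:

            >>> from copy import deepcopy
            >>> tensor = LabelTensor(torch.rand((4, 2)), ['u', 'v'])
            >>> copied = deepcopy(tensor)
            >>> copied.labels
            ['u', 'v']
            >>> copied is tensor
            False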
"""
labels = self.labels
copy_tensor = deepcopy(self.tensor)
return LabelTensor(copy_tensor, labels)
@property
def labels(self):
"""Property decorator for labels
:return: labels of self
:rtype: list
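
        :Example:
            A short sketch of reading and renaming the labels:

            >>> tensor = LabelTensor(torch.rand((5, 2)), ['x', 'y'])
            >>> tensor.labels
            ['x', 'y']
            >>> tensor.labels = ['a', 'b']
            >>> tensor.extract('b').shape
            torch.Size([5, 1])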
"""
return self._labels
@labels.setter
def labels(self, labels):
        if len(labels) != self.shape[-1]:  # check the label count
            raise ValueError(
                "The number of labels does not match the number of "
                "columns of the tensor."
            )
self._labels = labels # assign the label
@staticmethod
def vstack(label_tensors):
"""
        Stack tensors vertically. For more details, see
        :func:`torch.vstack`.

        :param list(LabelTensor) label_tensors: The tensors to stack. They
            must share the same labels.
        :return: The stacked tensor.
        :rtype: LabelTensor
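
        :Example:
            A minimal sketch; the stacked tensors share the same labels:

            >>> a = LabelTensor(torch.rand((3, 2)), ['x', 'y'])
            >>> b = LabelTensor(torch.rand((2, 2)), ['x', 'y'])
            >>> stacked = LabelTensor.vstack([a, b])
            >>> stacked.shape
            torch.Size([5, 2])
            >>> stacked.labels
            ['x', 'y']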
"""
if len(label_tensors) == 0:
return []
all_labels = [label for lt in label_tensors for label in lt.labels]
if set(all_labels) != set(label_tensors[0].labels):
raise RuntimeError("The tensors to stack have different labels")
labels = label_tensors[0].labels
tensors = [lt.extract(labels) for lt in label_tensors]
return LabelTensor(torch.vstack(tensors), labels)
def clone(self, *args, **kwargs):
"""
Clone the LabelTensor. For more details, see
:meth:`torch.Tensor.clone`.
:return: A copy of the tensor.
:rtype: LabelTensor
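
        :Example:
            A short sketch; the clone is an independent copy that keeps
            the labels:

            >>> tensor = LabelTensor(torch.rand((3, 2)), ['x', 'y'])
            >>> cloned = tensor.clone()
            >>> cloned.labels
            ['x', 'y']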
"""
        out = LabelTensor(super().clone(*args, **kwargs), self.labels)
        return out
def to(self, *args, **kwargs):
"""
Performs Tensor dtype and/or device conversion. For more details, see
:meth:`torch.Tensor.to`.
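
        :Example:
            A minimal sketch of a dtype conversion that keeps the
            labels:

            >>> tensor = LabelTensor(torch.rand((3, 2)), ['x', 'y'])
            >>> double = tensor.to(dtype=torch.float64)
            >>> double.dtype
            torch.float64
            >>> double.labels
            ['x', 'y']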
"""
tmp = super().to(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def select(self, *args, **kwargs):
"""
Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
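
        :Example:
            A minimal sketch; selecting a row along the first dimension
            keeps the column labels:

            >>> tensor = LabelTensor(torch.rand((4, 2)), ['x', 'y'])
            >>> row = tensor.select(0, 1)
            >>> row.labels
            ['x', 'y']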
"""
tmp = super().select(*args, **kwargs)
tmp._labels = self._labels
return tmp
def cuda(self, *args, **kwargs):
"""
        Send the tensor to a CUDA device. For more details, see
        :meth:`torch.Tensor.cuda`.
"""
tmp = super().cuda(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
def cpu(self, *args, **kwargs):
"""
        Send the tensor to the CPU. For more details, see
        :meth:`torch.Tensor.cpu`.
"""
tmp = super().cpu(*args, **kwargs)
new = self.__class__.clone(self)
new.data = tmp.data
return new
    def detach(self):
        """
        Detach the tensor from the computation graph, preserving the
        labels. For more details, see :meth:`torch.Tensor.detach`.
        """
detached = super().detach()
if hasattr(self, "_labels"):
detached._labels = self._labels
return detached
    def requires_grad_(self, mode=True):
        """
        Set the ``requires_grad`` attribute in-place, preserving the
        labels. For more details, see :meth:`torch.Tensor.requires_grad_`.
        """
lt = super().requires_grad_(mode)
lt.labels = self.labels
return lt
def append(self, lt, mode="std"):
"""
Return a copy of the merged tensors.
        :param LabelTensor lt: The tensor to merge. Its labels must be
            disjoint from the labels of ``self``.
        :param str mode: The merge strategy: ``'std'`` concatenates the
            columns row by row, ``'cross'`` concatenates the Cartesian
            product of the rows, while ``'first'`` is not yet implemented.
:return: The merged tensors.
:rtype: LabelTensor
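
        :Example:
            A short sketch of the two implemented modes:

            >>> t1 = LabelTensor(torch.rand((2, 1)), ['x'])
            >>> t2 = LabelTensor(torch.rand((2, 1)), ['y'])
            >>> t1.append(t2).labels
            ['x', 'y']
            >>> t1.append(t2, mode='cross').shape
            torch.Size([4, 2])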
"""
if set(self.labels).intersection(lt.labels):
raise RuntimeError("The tensors to merge have common labels")
new_labels = self.labels + lt.labels
if mode == "std":
new_tensor = torch.cat((self, lt), dim=1)
elif mode == "first":
raise NotImplementedError
elif mode == "cross":
tensor1 = self
tensor2 = lt
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
tensor2 = LabelTensor(
tensor2.repeat_interleave(n1, dim=0), labels=tensor2.labels
)
new_tensor = torch.cat((tensor1, tensor2), dim=1)
new_tensor = new_tensor.as_subclass(LabelTensor)
new_tensor.labels = new_labels
return new_tensor
def __getitem__(self, index):
"""
        Return a copy of the selected tensor. A string or a list of
        strings extracts the corresponding columns by label; any other
        index behaves as in :meth:`torch.Tensor.__getitem__`, preserving
        the labels whenever possible.
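
        :Example:
            A minimal sketch of label-based and positional indexing:

            >>> tensor = LabelTensor(torch.rand((4, 2)), ['x', 'y'])
            >>> tensor['x'].shape
            torch.Size([4, 1])
            >>> tensor[:2].labels
            ['x', 'y']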
"""
if isinstance(index, str) or (
isinstance(index, (tuple, list))
and all(isinstance(a, str) for a in index)
):
return self.extract(index)
selected_lt = super(Tensor, self).__getitem__(index)
try:
len_index = len(index)
except TypeError:
len_index = 1
if isinstance(index, int) or len_index == 1:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(1, -1)
if hasattr(self, "labels"):
selected_lt.labels = self.labels
elif len_index == 2:
if selected_lt.ndim == 1:
selected_lt = selected_lt.reshape(-1, 1)
if hasattr(self, "labels"):
if isinstance(index[1], list):
selected_lt.labels = [self.labels[i] for i in index[1]]
else:
selected_lt.labels = self.labels[index[1]]
else:
selected_lt.labels = self.labels
return selected_lt
@property
def tensor(self):
return self.as_subclass(Tensor)
def __len__(self) -> int:
return super().__len__()
def __str__(self):
if hasattr(self, "labels"):
s = f"labels({str(self.labels)})\n"
else:
s = "no labels\n"
s += super().__str__()
return s