DataModule#
- class Collator(max_conditions_lengths, automatic_batching, dataset=None)[source]#
Bases: object
This callable class is used to collate the data points fetched from the dataset. The collation is performed based on the type of dataset used and on the batching strategy.
Initialize the object, setting the collate function based on whether automatic batching is enabled or not.
- Parameters:
max_conditions_lengths (dict) – Dictionary containing the maximum number of data points to consider in a single batch for each condition.
automatic_batching (bool) – Whether automatic PyTorch batching is enabled or not. For more information, see the PinaDataModule class.
dataset (PinaDataset) – The dataset where the data is stored.
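A minimal, illustrative sketch of constructing the collator directly; in normal use it is created internally by PinaDataModule. The condition name "data" and the dataset object below are placeholder assumptions.
collator = Collator(
    max_conditions_lengths={"data": 32},  # at most 32 points per batch for the (hypothetical) condition "data"
    automatic_batching=False,             # collate whole slices instead of single items
    dataset=dataset,                      # an existing PinaDataset instance (assumed)
)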
- class PinaDataModule(problem, train_size=0.7, test_size=0.2, val_size=0.1, batch_size=None, shuffle=True, repeat=False, automatic_batching=None, num_workers=0, pin_memory=False)[source]#
Bases: LightningDataModule
This class extends LightningDataModule, allowing proper creation and management of the different types of datasets defined in PINA. Initialize the object, creating the datasets based on the input problem.
- Parameters:
problem (AbstractProblem) – The problem containing the data on which to create the datasets and dataloaders.
train_size (float) – Fraction of elements in the training split. It must be in the range [0, 1].
test_size (float) – Fraction of elements in the test split. It must be in the range [0, 1].
val_size (float) – Fraction of elements in the validation split. It must be in the range [0, 1].
batch_size (int) – The batch size used for training. If None, the entire dataset is returned in a single batch. Default is None.
shuffle (bool) – Whether to shuffle the dataset before splitting. Default is True.
repeat (bool) – If True, when the batch size is larger than the number of elements in a specific condition, the elements are repeated until the batch size is reached. If False, the number of elements in the batch is the minimum between the batch size and the number of elements in the condition. Default is False.
automatic_batching (bool) – If True, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. This is useful when the dataset is too large to fit into memory. On the other hand, if False, the items are retrieved from the dataset all at once, avoiding the overhead of collating them into a batch and reducing the number of __getitem__ calls to the dataset. This is useful when the dataset fits into memory. Avoid using automatic batching when batch_size is large. Default is False.
num_workers (int) – Number of worker threads for data loading. Default is 0 (serial loading).
pin_memory (bool) – Whether to use pinned memory for faster data transfer to GPU. Default is False.
- Raises:
ValueError – If at least one of the splits is negative.
ValueError – If the sum of the splits is different from 1.
See also
For more information on multi-process data loading, see: https://pytorch.org/docs/stable/data.html#multi-process-data-loading
For details on memory pinning, see: https://pytorch.org/docs/stable/data.html#memory-pinning
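A hedged usage sketch: building the data module from an already-defined problem. The import path and the problem object are assumptions and may differ in your installation; the split fractions are chosen to sum to 1, as required by the checks above.
from pina.data import PinaDataModule  # assumed import location

datamodule = PinaDataModule(
    problem=problem,           # an AbstractProblem with its data already defined (assumed to exist)
    train_size=0.7,
    test_size=0.2,
    val_size=0.1,
    batch_size=64,             # None would return each split as a single batch
    shuffle=True,
    repeat=False,
    automatic_batching=False,  # fetch whole slices at once; suits datasets that fit in memory
    num_workers=0,
    pin_memory=False,
)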
- setup(stage=None)[source]#
Create the dataset objects for the given stage. If the stage is “fit”, the training and validation datasets are created. If the stage is “test”, the testing dataset is created.
- Parameters:
stage (str) – The stage for which to perform the dataset setup.
- Raises:
ValueError – If the stage is neither “fit” nor “test”.
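A minimal sketch of the stage-based setup; when the data module is passed to a Lightning Trainer, setup is normally called automatically, so the explicit calls below are only illustrative.
datamodule.setup(stage="fit")   # creates the training and validation datasets
datamodule.setup(stage="test")  # creates the testing dataset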
- val_dataloader()[source]#
Create the validation dataloader.
- Returns:
The validation dataloader
- Return type:
- train_dataloader()[source]#
Create the training dataloader.
- Returns:
The training dataloader
- Return type:
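An illustrative sketch of retrieving and iterating the dataloaders manually; in practice the Lightning Trainer drives these calls, and the exact batch structure depends on the problem's conditions and the chosen batching strategy.
datamodule.setup(stage="fit")
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
for batch in train_loader:
    # each batch gathers the points drawn from the problem's conditions,
    # capped per condition according to the batching strategy
    ...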
- class PinaSampler(dataset, shuffle)[source]#
Bases: object
This class is used to create the sampler instance based on the shuffle parameter and the environment in which the code is running.
Instantiate and initialize the sampler.
- Parameters:
dataset (PinaDataset) – The dataset from which to sample.
shuffle (bool) – Whether to shuffle the dataset.
- Returns:
The sampler instance.
- Return type:
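A hedged construction sketch, assuming an existing PinaDataset instance named dataset; depending on the environment in which the code runs (e.g. distributed training), a different sampler class may be returned.
sampler = PinaSampler(dataset, shuffle=True)  # returns the appropriate sampler instance for the current environment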