Source code for pina._src.data.aggregator

"""Utility class for aggregating multiple dataloaders into a single iterable."""


[docs] class _Aggregator: """ Aggregate multiple dataloaders into a unified iterable object. The aggregator combines batches produced by multiple dataloaders according to the selected batching strategy. It is primarily used to coordinate the iteration of multiple training conditions within a single training loop. """ def __init__(self, dataloaders, batching_mode): """ Initialization of the :class:`_Aggregator` class. :param dict[str, DataLoader] dataloaders: The mapping between condition names and their corresponding dataloaders. :param str batching_mode: The strategy used to aggregate batches across dataloaders. Available options are ``"common_batch_size"`` for uniform batch sizes across conditions, ``"proportional"`` for batch sizes proportional to dataset sizes, and ``"separate_conditions"`` for iterating through each condition separately. :raises NotImplementedError: If the selected batching mode is not yet implemented. """ # Raise not implemented error for separate_conditions mode if batching_mode == "separate_conditions": raise NotImplementedError( "Batching mode 'separate_conditions' is not implemented yet." ) # Initialize attributes self.dataloaders = dataloaders self.batching_mode = batching_mode def __len__(self): """ Return the length of the aggregated dataloader. The length is determined by the number of iterations required to exhaust the dataloaders based on the selected batching mode. For ``"separate_conditions"``, the total number of iterations is the sum of the lengths of all dataloaders. For all other batching modes, the length corresponds to the maximum length among the aggregated dataloaders. :return: The length of the aggregated dataloader. :rtype: int """ # Separate conditions case if self.batching_mode == "separate_conditions": return sum(len(dl) for dl in self.dataloaders.values()) return max(len(dl) for dl in self.dataloaders.values()) def __iter__(self): """ Iterate over the aggregated dataloaders. At each iteration, a dictionary containing one batch per dataloader is yielded. If a dataloader is exhausted before the others, its iterator is restarted automatically to ensure continuous batch generation. :yield: The dictionary mapping each condition name to its batch. :rtype: Iterator[dict[str, Any]] """ # Initialize iterators for each dataloader iterators = {name: iter(dl) for name, dl in self.dataloaders.items()} # Iterate until the maximum number of iterations is reached for _ in range(len(self)): batch = {} # Generate a batch for each dataloader for name, dataloader in self.dataloaders.items(): # Attempt to get the next batch from the dataloader's iterator try: batch[name] = next(iterators[name]) # Restart the iterator if it is exhausted except StopIteration: iterators[name] = iter(dataloader) batch[name] = next(iterators[name]) yield batch