@dataclass
class BaseDataModuleConfig:
    """Configuration values consumed by :class:`BaseDataModule`.

    Args:
        data_dir: See :paramref:`~.BaseSubtaskConfig.data_dir`.
        device: See :paramref:`~.OptimizationSubtaskConfig.device`.
        max_per_device_batch_size: Upper bound placed on
            :attr:`~BaseDataModule.per_device_batch_size`.
        fixed_per_device_batch_size: See
            :attr:`~BaseDataModule.per_device_batch_size`. When set,
            the batch size search in
            :func:`.find_good_per_device_batch_size` is skipped, which
            is not recommended for resource efficiency.
        fixed_per_device_num_workers: See
            :attr:`~BaseDataModule.per_device_num_workers`. When set,
            the num workers search in
            :func:`.find_good_per_device_num_workers` is skipped,
            which is not recommended for resource efficiency.
        shuffle_train_dataset: ``shuffle`` value forwarded by
            :meth:`BaseDataModule.train_dataloader`.
        shuffle_val_dataset: ``shuffle`` value forwarded by
            :meth:`BaseDataModule.val_dataloader`.
        drop_last: See
            :paramref:`~torch.utils.data.DataLoader.drop_last`.
    """

    # Interpolation strings, stored verbatim as the field defaults.
    data_dir: An[str, not_empty()] = "${config.data_dir}"
    device: An[str, one_of("cpu", "gpu")] = "${config.device}"

    # Optional batch size / worker count overrides (``None`` = unset).
    max_per_device_batch_size: An[int, ge(1)] | None = None
    fixed_per_device_batch_size: An[int, ge(1)] | None = None
    fixed_per_device_num_workers: An[int, ge(0)] | None = None

    # ``DataLoader`` behaviour toggles.
    shuffle_train_dataset: bool = True
    shuffle_val_dataset: bool = True
    drop_last: bool = False
class BaseDataModule(LightningDataModule, ABC):
    """Base :class:`lightning.pytorch.core.LightningDataModule`.

    With ``<stage>`` being any of ``train``, ``val``, ``test`` or
    ``predict``, subclasses need to properly define the
    ``datasets.<stage>`` attribute(s) for each desired stage.

    Args:
        config: The :class:`BaseDataModuleConfig` instance stored as
            :attr:`config`.

    Attributes:
        config (BaseDataModuleConfig)
        datasets (Datasets)
        collate_fn (typing.Callable): See ``collate_fn`` argument in
            :class:`torch.utils.data.DataLoader`.
        pin_memory (bool): Whether to copy tensors into device pinned
            memory before returning them (is set to ``True`` by
            default if :paramref:`~BaseDataModuleConfig.device` is
            ``"gpu"``).
        per_device_batch_size (int): Per-device number of samples to
            load per iteration. Temporary value (``1``) is overwritten
            in :func:`.set_batch_size_and_num_workers`.
        per_device_num_workers (int): Per-device number of CPU
            processes to use for data loading (``0`` means that the
            data will be loaded by each device's assigned CPU
            process). Temporary value (``0``) is later overwritten in
            :func:`.set_batch_size_and_num_workers`.
    """

    def __init__(
        self: "BaseDataModule",
        config: BaseDataModuleConfig,
    ) -> None:
        super().__init__()
        self.config = config
        # Starts empty; subclasses fill in the per-stage datasets.
        self.datasets = Datasets()
        self.collate_fn = None
        self.pin_memory = self.config.device == "gpu"
        # Placeholder values, see the class docstring.
        self.per_device_batch_size = 1
        self.per_device_num_workers = 0

    @final
    def state_dict(self: "BaseDataModule") -> dict[str, int]:
        """Returns instance attribute values.

        Returns:
            A new dictionary containing attribute values
            :attr:`per_device_batch_size` &
            :attr:`per_device_num_workers`.
        """
        return {
            "per_device_batch_size": self.per_device_batch_size,
            "per_device_num_workers": self.per_device_num_workers,
        }

    @final
    def x_dataloader(
        self: "BaseDataModule",
        dataset: Dataset[Tensor] | HFDataset | None,
        *,
        shuffle: bool = True,
    ) -> DataLoader[Tensor]:
        """Generic :class:`torch.utils.data.DataLoader` factory method.

        Args:
            dataset: A :mod:`torch` ``Dataset`` to wrap with a
                :class:`torch.utils.data.DataLoader`.
            shuffle: Whether to shuffle the dataset when iterating
                over it.

        Raises:
            AttributeError: If :paramref:`dataset` is :obj:`None`.

        Returns:
            A new :class:`torch.utils.data.DataLoader` instance
            wrapping the :paramref:`dataset` argument.
        """
        if dataset is None:
            # Same exception type as before, now with a diagnostic so
            # a missing ``datasets.<stage>`` is easy to track down.
            error_msg = (
                "Cannot create a `DataLoader`: the dataset for the "
                "requested stage is `None`. Make sure the "
                "corresponding `datasets.<stage>` attribute is set."
            )
            raise AttributeError(error_msg)
        return DataLoader(
            dataset=dataset,
            batch_size=self.per_device_batch_size,
            shuffle=shuffle,
            num_workers=self.per_device_num_workers,
            collate_fn=self.collate_fn,
            pin_memory=self.pin_memory,
            drop_last=self.config.drop_last,
        )

    @final
    def train_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
        """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.train``.

        Shuffling follows
        :paramref:`~BaseDataModuleConfig.shuffle_train_dataset`.

        Returns:
            A new training :class:`torch.utils.data.DataLoader`
            instance.
        """
        return self.x_dataloader(
            dataset=self.datasets.train,
            shuffle=self.config.shuffle_train_dataset,
        )

    # NOTE(review): unlike its sibling dataloader methods, this one is
    # not decorated with ``@final`` -- presumably so that subclasses
    # may override it; confirm before changing.
    def val_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
        """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.val``.

        Shuffling follows
        :paramref:`~BaseDataModuleConfig.shuffle_val_dataset`.

        Returns:
            A new validation :class:`torch.utils.data.DataLoader`
            instance.
        """
        return self.x_dataloader(
            dataset=self.datasets.val,
            shuffle=self.config.shuffle_val_dataset,
        )

    @final
    def test_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
        """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.test``.

        Returns:
            A new testing :class:`torch.utils.data.DataLoader`
            instance that does not shuffle the dataset.
        """
        return self.x_dataloader(dataset=self.datasets.test, shuffle=False)

    @final
    def predict_dataloader(self: "BaseDataModule") -> DataLoader[Tensor]:
        """Calls :meth:`x_dataloader` w/ :attr:`datasets` ``.predict``.

        Returns:
            A new prediction :class:`torch.utils.data.DataLoader`
            instance that does not shuffle the dataset.
        """
        return self.x_dataloader(
            dataset=self.datasets.predict,
            shuffle=False,
        )