Source code for cellarium.ml.core.datamodule

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


from typing import Literal

import lightning.pytorch as pl
import torch
from anndata import AnnData

from cellarium.ml.data import DistributedAnnDataCollection, IterableDistributedAnnDataCollectionDataset
from cellarium.ml.utilities.core import train_val_split
from cellarium.ml.utilities.data import AnnDataField, collate_fn


class CellariumAnnDataDataModule(pl.LightningDataModule):
    """
    DataModule for :class:`~cellarium.ml.data.IterableDistributedAnnDataCollectionDataset`.

    Example::

        >>> from cellarium.ml import CellariumAnnDataDataModule
        >>> from cellarium.ml.data import DistributedAnnDataCollection
        >>> from cellarium.ml.utilities.data import AnnDataField, densify

        >>> dm = CellariumAnnDataDataModule(
        ...     DistributedAnnDataCollection(
        ...         "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...         shard_size=10_000,
        ...         max_cache_size=2,
        ...     ),
        ...     batch_keys={
        ...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
        ...         "var_names_g": AnnDataField(attr="var_names"),
        ...     },
        ...     batch_size=5000,
        ...     iteration_strategy="cache_efficient",
        ...     shuffle=True,
        ...     seed=0,
        ...     drop_last_indices=True,
        ...     num_workers=4,
        ... )
        >>> dm.setup(stage="fit")
        >>> for batch in dm.train_dataloader():
        ...     print(batch.keys())  # x_ng, var_names_g

    Args:
        dadc:
            An instance of :class:`~cellarium.ml.data.DistributedAnnDataCollection` or :class:`AnnData`.
        batch_keys:
            Dictionary that specifies which attributes and keys of :attr:`dadc` to return in
            the batch data and how to convert them. Keys must correspond to the input keys of
            the transforms or the model. Values must be instances of
            :class:`cellarium.ml.utilities.data.AnnDataField`.
        batch_size:
            How many samples per batch to load.
        iteration_strategy:
            Strategy to use for iterating through the dataset. Options are ``same_order`` and
            ``cache_efficient``. ``same_order`` iterates through the dataset in the same order
            independent of the number of replicas and workers. ``cache_efficient`` tries to
            minimize the number of anndata files fetched by each worker.
        shuffle:
            If ``True``, the data is reshuffled at every epoch.
        seed:
            Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        drop_last_indices:
            If ``True``, the sampler drops the tail of the data to make it evenly divisible
            across the number of replicas. If ``False``, the sampler adds extra indices to
            make the data evenly divisible across the replicas.
        drop_incomplete_batch:
            If ``True``, the dataloader drops the incomplete batch if the dataset size is not
            divisible by the batch size.
        train_size:
            Size of the train split. If :class:`float`, should be between ``0.0`` and ``1.0``
            and represent the proportion of the dataset to include in the train split.
            If :class:`int`, represents the absolute number of train samples. If ``None``,
            the value is automatically set to the complement of ``val_size``.
        val_size:
            Size of the validation split. If :class:`float`, should be between ``0.0`` and
            ``1.0`` and represent the proportion of the dataset to include in the validation
            split. If :class:`int`, represents the absolute number of validation samples.
            If ``None``, the value is set to the complement of ``train_size``. If
            ``train_size`` is also ``None``, it will be set to ``0``.
        test_mode:
            If ``True``, enables tracking of cache and worker information.
        num_workers:
            How many subprocesses to use for data loading. ``0`` means that the data will be
            loaded in the main process.
        prefetch_factor:
            Number of batches loaded in advance by each worker. ``2`` means there will be a
            total of ``2 * num_workers`` batches prefetched across all workers. The default
            depends on ``num_workers``: if ``num_workers=0``, the default is ``None``;
            otherwise, the default is ``2``.
        persistent_workers:
            If ``True``, the data loader will not shut down the worker processes after a
            dataset has been consumed once. This keeps the workers' ``Dataset`` instances
            alive.
    """

    def __init__(
        self,
        dadc: DistributedAnnDataCollection | AnnData,
        # IterableDistributedAnnDataCollectionDataset args
        batch_keys: dict[str, AnnDataField] | None = None,
        batch_size: int = 1,
        iteration_strategy: Literal["same_order", "cache_efficient"] = "cache_efficient",
        shuffle: bool = False,
        seed: int = 0,
        drop_last_indices: bool = False,
        drop_incomplete_batch: bool = False,
        train_size: float | int | None = None,
        val_size: float | int | None = None,
        test_mode: bool = False,
        # DataLoader args
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
    ) -> None:
        super().__init__()
        self.save_hyperparameters(logger=False)
        # Don't save dadc to the checkpoint
        self.hparams["dadc"] = None
        self.dadc = dadc
        # IterableDistributedAnnDataCollectionDataset args
        self.batch_keys = batch_keys or {}
        self.batch_size = batch_size
        self.iteration_strategy = iteration_strategy
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last_indices = drop_last_indices
        self.n_train, self.n_val = train_val_split(len(dadc), train_size, val_size)
        self.test_mode = test_mode
        # DataLoader args
        self.num_workers = num_workers
        self.collate_fn = collate_fn
        self.drop_incomplete_batch = drop_incomplete_batch
        self.prefetch_factor = prefetch_factor
        self.persistent_workers = persistent_workers
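
    # Rough sketch of the split arithmetic, assuming the train_size/val_size semantics
    # documented above; the exact rounding behavior is defined by
    # cellarium.ml.utilities.core.train_val_split. For a collection of 100 cells:
    #
    #   train_val_split(100, train_size=0.8, val_size=0.2)    # -> (80, 20)
    #   train_val_split(100, train_size=90, val_size=None)    # -> (90, 10)
    #   train_val_split(100, train_size=None, val_size=25)    # -> (75, 25)
    #   train_val_split(100, train_size=None, val_size=None)  # -> (100, 0)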

    def setup(self, stage: str | None = None) -> None:
        """
        .. note::
            setup is called from every process across all the nodes. Setting state here
            is recommended.

        .. note::
            :attr:`val_dataset` is not shuffled and uses the ``same_order`` iteration
            strategy.
        """
        if stage == "fit":
            self.train_dataset = IterableDistributedAnnDataCollectionDataset(
                dadc=self.dadc,
                batch_keys=self.batch_keys,
                batch_size=self.batch_size,
                iteration_strategy=self.iteration_strategy,
                shuffle=self.shuffle,
                seed=self.seed,
                drop_last_indices=self.drop_last_indices,
                drop_incomplete_batch=self.drop_incomplete_batch,
                test_mode=self.test_mode,
                start_idx=0,
                end_idx=self.n_train,
            )
            self.val_dataset = IterableDistributedAnnDataCollectionDataset(
                dadc=self.dadc,
                batch_keys=self.batch_keys,
                batch_size=self.batch_size,
                iteration_strategy="same_order",
                shuffle=False,
                seed=self.seed,
                drop_last_indices=self.drop_last_indices,
                drop_incomplete_batch=self.drop_incomplete_batch,
                test_mode=self.test_mode,
                start_idx=self.n_train,
                end_idx=self.n_train + self.n_val,
            )

        if stage == "predict":
            self.predict_dataset = IterableDistributedAnnDataCollectionDataset(
                dadc=self.dadc,
                batch_keys=self.batch_keys,
                batch_size=self.batch_size,
                iteration_strategy=self.iteration_strategy,
                shuffle=self.shuffle,
                seed=self.seed,
                drop_last_indices=self.drop_last_indices,
                drop_incomplete_batch=self.drop_incomplete_batch,
                test_mode=self.test_mode,
            )
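
    # Usage sketch, assuming a hypothetical in-memory AnnData ``adata``: Lightning's
    # Trainer calls setup() automatically, but the datamodule can also be driven by hand.
    #
    #   dm = CellariumAnnDataDataModule(adata, batch_size=128, train_size=0.9, val_size=0.1)
    #   dm.setup(stage="fit")
    #   # train_dataset iterates indices [0, n_train); val_dataset iterates
    #   # [n_train, n_train + n_val), unshuffled and in "same_order".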

    def train_dataloader(self) -> torch.utils.data.DataLoader:
        """Training dataloader."""
        return torch.utils.data.DataLoader(
            self.train_dataset,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
        )

    def val_dataloader(self) -> torch.utils.data.DataLoader:
        """Validation dataloader."""
        return torch.utils.data.DataLoader(
            self.val_dataset,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
        )

    def predict_dataloader(self) -> torch.utils.data.DataLoader:
        """Prediction dataloader."""
        return torch.utils.data.DataLoader(
            self.predict_dataset,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
            prefetch_factor=self.prefetch_factor,
            persistent_workers=self.persistent_workers,
        )
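

# Minimal end-to-end sketch of the datamodule on a small in-memory AnnData, as
# permitted by the ``dadc`` type hint above. Names and shapes here are illustrative
# assumptions, not part of the library: real workloads would pass a
# DistributedAnnDataCollection, and a sparse ``X`` would need ``convert_fn=densify``
# as in the class docstring.
if __name__ == "__main__":
    import numpy as np

    adata = AnnData(X=np.random.default_rng(0).poisson(1.0, size=(100, 10)).astype(np.float32))
    dm = CellariumAnnDataDataModule(
        adata,
        batch_keys={
            "x_ng": AnnDataField(attr="X"),
            "var_names_g": AnnDataField(attr="var_names"),
        },
        batch_size=25,
        train_size=0.8,
        val_size=0.2,
    )
    dm.setup(stage="fit")  # creates train_dataset (80 cells) and val_dataset (20 cells)
    batch = next(iter(dm.train_dataloader()))
    print({k: getattr(v, "shape", v) for k, v in batch.items()})  # e.g. x_ng: (25, 10)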