# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from typing import Any, Literal
import lightning.pytorch as pl
import torch
from anndata import AnnData
from cellarium.ml.data import DistributedAnnDataCollection, IterableDistributedAnnDataCollectionDataset
from cellarium.ml.utilities.core import train_val_split
from cellarium.ml.utilities.data import AnnDataField, collate_fn
class CellariumAnnDataDataModule(pl.LightningDataModule):
"""
DataModule for :class:`~cellarium.ml.data.IterableDistributedAnnDataCollectionDataset`.
Example::
>>> from cellarium.ml import CellariumAnnDataDataModule
>>> from cellarium.ml.data import DistributedAnnDataCollection
>>> from cellarium.ml.utilities.data import AnnDataField, densify
>>> dm = CellariumAnnDataDataModule(
... DistributedAnnDataCollection(
... "gs://bucket-name/folder/adata{000..005}.h5ad",
... shard_size=10_000,
... ),
... max_cache_size=2,
... batch_keys={
... "x_ng": AnnDataField(attr="X", convert_fn=densify),
... "var_names_g": AnnDataField(attr="var_names"),
... },
... batch_size=5000,
... iteration_strategy="cache_efficient",
... shuffle=True,
... shuffle_seed=0,
... drop_last_indices=True,
... num_workers=4,
... )
        >>> dm.setup(stage="fit")
>>> for batch in dm.train_dataloader():
... print(batch.keys()) # x_ng, var_names_g
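
    Setting up with ``stage="fit"`` also creates the validation dataset, so the validation split
    can be iterated the same way (a sketch, assuming a nonzero ``val_size`` was configured)::

        >>> for batch in dm.val_dataloader():
        ...     print(batch.keys())  # x_ng, var_names_g
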
Args:
dadc:
An instance of :class:`~cellarium.ml.data.DistributedAnnDataCollection` or :class:`AnnData`.
batch_keys:
Dictionary that specifies which attributes and keys of the :attr:`dadc` to return
in the batch data and how to convert them. Keys must correspond to
the input keys of the transforms or the model. Values must be instances of
:class:`cellarium.ml.utilities.data.AnnDataField`.
batch_size:
How many samples per batch to load.
iteration_strategy:
Strategy to use for iterating through the dataset. Options are ``same_order`` and ``cache_efficient``.
``same_order`` will iterate through the dataset in the same order independent of the number of replicas
            and workers. ``cache_efficient`` will try to minimize the number of AnnData files fetched by each worker.
shuffle:
If ``True``, the data is reshuffled at every epoch.
shuffle_seed:
Random seed used to shuffle the sampler if :attr:`shuffle=True`.
drop_last_indices:
If ``True``, then the sampler will drop the tail of the data
to make it evenly divisible across the number of replicas. If ``False``,
the sampler will add extra indices to make the data evenly divisible across
the replicas.
drop_incomplete_batch:
If ``True``, the dataloader will drop the incomplete batch if the dataset size is not divisible by
the batch size.
train_size:
Size of the train split. If :class:`float`, should be between ``0.0`` and ``1.0`` and represent
the proportion of the dataset to include in the train split. If :class:`int`, represents
the absolute number of train samples. If ``None``, the value is automatically set to the complement
of the ``val_size``.
val_size:
Size of the validation split. If :class:`float`, should be between ``0.0`` and ``1.0`` and represent
the proportion of the dataset to include in the validation split. If :class:`int`, represents
the absolute number of validation samples. If ``None``, the value is set to the complement of
            the ``train_size``. If ``train_size`` is also ``None``, it will be set to ``0``. See the split
            example below.
worker_seed:
Random seed used to seed the workers. If ``None``, then the workers will not be seeded.
The seed of the individual worker is computed based on the ``worker_seed``, global worker id,
            and the epoch. Note that this seed affects ``cpu_transforms`` when they are used.
When resuming training, the seed should be set to a different value to ensure that the
workers are not seeded with the same seed as the previous run.
test_mode:
            If ``True``, enables tracking of cache and worker information.
num_workers:
How many subprocesses to use for data loading. ``0`` means that the data will be loaded in the main process.
prefetch_factor:
            Number of batches loaded in advance by each worker. ``2`` means there will be a total of
            ``2 * num_workers`` batches prefetched across all workers. For example, ``num_workers=4`` with
            ``prefetch_factor=2`` keeps up to 8 batches prefetched. The default depends on ``num_workers``:
            if ``num_workers=0`` the default is ``None``; otherwise, the default is ``2``.
persistent_workers:
If ``True``, the data loader will not shut down the worker processes after a dataset has been consumed once.
            This keeps the workers' ``Dataset`` instances alive.
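
    The train/validation split is computed by :func:`~cellarium.ml.utilities.core.train_val_split`.
    A minimal sketch of the expected sizes (assuming 10,000 cells with an 80/20 split)::

        >>> from cellarium.ml.utilities.core import train_val_split
        >>> n_train, n_val = train_val_split(10_000, 0.8, 0.2)  # expected to yield (8000, 2000)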
"""
def __init__(
self,
dadc: DistributedAnnDataCollection | AnnData,
# IterableDistributedAnnDataCollectionDataset args
batch_keys: dict[str, dict[str, AnnDataField] | AnnDataField] | None = None,
batch_size: int = 1,
iteration_strategy: Literal["same_order", "cache_efficient"] = "cache_efficient",
shuffle: bool = False,
shuffle_seed: int = 0,
drop_last_indices: bool = False,
drop_incomplete_batch: bool = False,
train_size: float | int | None = None,
val_size: float | int | None = None,
worker_seed: int | None = None,
test_mode: bool = False,
# DataLoader args
num_workers: int = 0,
prefetch_factor: int | None = None,
persistent_workers: bool = False,
) -> None:
super().__init__()
self.save_hyperparameters(logger=False)
# Don't save dadc to the checkpoint
self.hparams["dadc"] = None
self.dadc = dadc
# IterableDistributedAnnDataCollectionDataset args
self.batch_keys = batch_keys or {}
self.batch_size = batch_size
self.iteration_strategy = iteration_strategy
self.shuffle = shuffle
self.shuffle_seed = shuffle_seed
self.drop_last_indices = drop_last_indices
self.n_train, self.n_val = train_val_split(len(dadc), train_size, val_size)
self.worker_seed = worker_seed
self.test_mode = test_mode
# DataLoader args
self.num_workers = num_workers
self.collate_fn = collate_fn
self.drop_incomplete_batch = drop_incomplete_batch
self.prefetch_factor = prefetch_factor
self.persistent_workers = persistent_workers
def setup(self, stage: str | None = None) -> None:
"""
.. note::
setup is called from every process across all the nodes. Setting state here is recommended.
.. note::
:attr:`val_dataset` is not shuffled and uses the ``same_order`` iteration strategy.
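
        For example, a sketch of the stage-to-dataset mapping implemented below::

            >>> dm.setup(stage="fit")       # creates train_dataset and val_dataset
            >>> dm.setup(stage="validate")  # creates val_dataset
            >>> dm.setup(stage="predict")   # creates predict_dataset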
"""
if stage == "fit":
self.train_dataset = IterableDistributedAnnDataCollectionDataset(
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy=self.iteration_strategy,
shuffle=self.shuffle,
shuffle_seed=self.shuffle_seed,
drop_last_indices=self.drop_last_indices,
drop_incomplete_batch=self.drop_incomplete_batch,
worker_seed=self.worker_seed,
test_mode=self.test_mode,
start_idx=0,
end_idx=self.n_train,
)
if stage in {"fit", "validate"}:
self.val_dataset = IterableDistributedAnnDataCollectionDataset(
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy="same_order",
shuffle=False,
shuffle_seed=self.shuffle_seed,
drop_last_indices=self.drop_last_indices,
drop_incomplete_batch=self.drop_incomplete_batch,
worker_seed=self.worker_seed,
test_mode=self.test_mode,
start_idx=self.n_train,
end_idx=self.n_train + self.n_val,
)
if stage == "predict":
self.predict_dataset = IterableDistributedAnnDataCollectionDataset(
dadc=self.dadc,
batch_keys=self.batch_keys,
batch_size=self.batch_size,
iteration_strategy=self.iteration_strategy,
shuffle=self.shuffle,
shuffle_seed=self.shuffle_seed,
drop_last_indices=self.drop_last_indices,
drop_incomplete_batch=self.drop_incomplete_batch,
worker_seed=self.worker_seed,
test_mode=self.test_mode,
)
def train_dataloader(self) -> torch.utils.data.DataLoader:
"""Training dataloader."""
return torch.utils.data.DataLoader(
self.train_dataset,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
)
def val_dataloader(self) -> torch.utils.data.DataLoader:
"""Validation dataloader."""
return torch.utils.data.DataLoader(
self.val_dataset,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
)
def predict_dataloader(self) -> torch.utils.data.DataLoader:
"""Prediction dataloader."""
return torch.utils.data.DataLoader(
self.predict_dataset,
num_workers=self.num_workers,
collate_fn=self.collate_fn,
prefetch_factor=self.prefetch_factor,
persistent_workers=self.persistent_workers,
)
def state_dict(self) -> dict[str, Any]:
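        """
        Called by Lightning when saving a checkpoint. Returns the dataloader and trainer
        configuration (iteration strategy, number of workers/replicas/nodes, batch size,
        shuffling, split sizes, and the current epoch and global step) needed to validate
        and resume training from the checkpoint.
        """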
assert self.trainer is not None
state = {
"iteration_strategy": self.iteration_strategy,
"num_workers": self.num_workers,
"num_replicas": self.trainer.num_devices,
"num_nodes": self.trainer.num_nodes,
"batch_size": self.batch_size,
"accumulate_grad_batches": self.trainer.accumulate_grad_batches,
"shuffle": self.shuffle,
"shuffle_seed": self.shuffle_seed,
"drop_last_indices": self.drop_last_indices,
"drop_incomplete_batch": self.drop_incomplete_batch,
"n_train": self.n_train,
"worker_seed": self.worker_seed,
"epoch": self.trainer.current_epoch,
"resume_step": self.trainer.global_step,
}
return state
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
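        """
        Called by Lightning when loading a checkpoint. Verifies that the current configuration
        matches the one stored in the checkpoint, raising :class:`ValueError` on a mismatch,
        and forwards the loaded state to the train dataset.
        """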
if hasattr(self, "train_dataset"):
assert self.trainer is not None
if state_dict["iteration_strategy"] != self.iteration_strategy:
raise ValueError(
"Cannot resume training with a different iteration strategy. "
f"Expected {self.iteration_strategy}, got {state_dict['iteration_strategy']}."
)
if state_dict["num_workers"] != self.num_workers:
raise ValueError(
"Cannot resume training with a different number of workers. "
f"Expected {self.num_workers}, got {state_dict['num_workers']}."
)
if state_dict["num_replicas"] != self.trainer.num_devices:
raise ValueError(
"Cannot resume training with a different number of replicas. "
f"Expected {self.trainer.num_devices}, got {state_dict['num_replicas']}."
)
if state_dict["num_nodes"] != self.trainer.num_nodes:
raise ValueError(
"Cannot resume training with a different number of nodes. "
f"Expected {self.trainer.num_nodes}, got {state_dict['num_nodes']}."
)
if state_dict["batch_size"] != self.batch_size:
raise ValueError(
"Cannot resume training with a different batch size. "
f"Expected {self.batch_size}, got {state_dict['batch_size']}."
)
if state_dict["accumulate_grad_batches"] != 1:
raise ValueError("Training with gradient accumulation is not supported when resuming training.")
if state_dict["shuffle"] != self.shuffle:
raise ValueError(
"Cannot resume training with a different shuffle value. "
f"Expected {self.shuffle}, got {state_dict['shuffle']}."
)
if state_dict["shuffle_seed"] != self.shuffle_seed:
raise ValueError(
"Cannot resume training with a different shuffle seed. "
f"Expected {self.shuffle_seed}, got {state_dict['shuffle_seed']}."
)
if state_dict["drop_last_indices"] != self.drop_last_indices:
raise ValueError(
"Cannot resume training with a different drop_last_indices value. "
f"Expected {self.drop_last_indices}, got {state_dict['drop_last_indices']}."
)
if state_dict["drop_incomplete_batch"] != self.drop_incomplete_batch:
raise ValueError(
"Cannot resume training with a different drop_incomplete_batch value. "
f"Expected {self.drop_incomplete_batch}, got {state_dict['drop_incomplete_batch']}."
)
if state_dict["n_train"] != self.n_train:
raise ValueError(
"Cannot resume training with a different train size. "
f"Expected {self.n_train}, got {state_dict['n_train']}."
)
if (self.worker_seed is not None) and (state_dict["worker_seed"] == self.worker_seed):
warnings.warn(
"Resuming training with the same worker seed as the previous run. "
"This may lead to repeated behavior in the workers upon resuming training."
)
self.train_dataset.load_state_dict(state_dict)