Source code for cellarium.ml.data.dadc_dataset

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import math
import random
from itertools import islice
from typing import Any, Literal

import numpy as np
import torch
from anndata import AnnData
from boltons.iterutils import chunked_iter
from torch.utils._pytree import tree_map
from torch.utils.data import IterableDataset

from cellarium.ml.data.distributed_anndata import DistributedAnnDataCollection
from cellarium.ml.utilities.data import AnnDataField
from cellarium.ml.utilities.distributed import get_rank_and_num_replicas, get_worker_info


class IterableDistributedAnnDataCollectionDataset(IterableDataset):
    r"""
    Iterable DistributedAnnDataCollection Dataset.

    When :attr:`shuffle` is set to ``True`` then the iterator yields datapoints that are
    uniformly sampled from the entire dataset. Typical use cases include training variational
    models using the stochastic gradient descent algorithm.

    In order to maximize buffer usage, we only shuffle shards and datapoints within individual
    shards (and not across shards). Therefore, to achieve unbiased pseudo-random uniform
    sampling, it is imperative that the shards themselves contain datapoints that are uniformly
    sampled from the entire dataset. If correlations exist between datapoints in a given shard
    (e.g., all cells coming from the same tissue or experiment), then this assumption is
    violated. It is the user's responsibility to prepare appropriately shuffled data shards.

    Example::

        >>> from cellarium.ml.data import (
        ...     DistributedAnnDataCollection,
        ...     IterableDistributedAnnDataCollectionDataset,
        ... )
        >>> from cellarium.ml.utilities.data import AnnDataField, densify

        >>> dadc = DistributedAnnDataCollection(
        ...     "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...     shard_size=10_000,
        ...     max_cache_size=2)

        >>> dataset = IterableDistributedAnnDataCollectionDataset(
        ...     dadc,
        ...     batch_keys={
        ...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
        ...         "var_names_g": AnnDataField(attr="var_names"),
        ...     },
        ...     batch_size=5000,
        ...     iteration_strategy="cache_efficient",
        ...     shuffle=True,
        ...     shuffle_seed=0,
        ...     drop_last_indices=True,
        ... )

    Args:
        dadc:
            DistributedAnnDataCollection or AnnData from which to load the data.
        batch_keys:
            Dictionary that specifies which attributes and keys of the :attr:`dadc` to return
            in the batch data and how to convert them. Keys must correspond to the input keys
            of the transforms or the model. Values must be instances of
            :class:`cellarium.ml.utilities.data.AnnDataField`.
        batch_size:
            How many samples per batch to load.
        iteration_strategy:
            Strategy to use for iterating through the dataset. Options are ``same_order``
            and ``cache_efficient``. ``same_order`` will iterate through the dataset in the
            same order independent of the number of replicas and workers. ``cache_efficient``
            will try to minimize the amount of anndata files fetched by each worker.
        shuffle:
            If ``True``, the data is reshuffled at every epoch.
        shuffle_seed:
            Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        drop_last_indices:
            If ``True``, then the sampler will drop the tail of the data to make it evenly
            divisible across the number of replicas. If ``False``, the sampler will add extra
            indices to make the data evenly divisible across the replicas.
        drop_incomplete_batch:
            If ``True``, the dataloader will drop the incomplete batch if the dataset size
            is not divisible by the batch size.
        start_idx:
            The starting index of the dataset. If ``None``, then the dataset will start from
            the first index.
        end_idx:
            The ending index (exclusive) of the dataset. If ``None``, then the dataset will
            end at the last index (inclusive).
        worker_seed:
            Random seed used to seed the workers. If ``None``, then the workers will not be
            seeded. The seed of an individual worker is computed based on the ``worker_seed``,
            the global worker id, and the epoch. Note that this seed affects ``cpu_transforms``
            when they are used. When resuming training, the seed should be set to a different
            value to ensure that the workers are not seeded with the same seed as the
            previous run.
        test_mode:
            If ``True``, then tracking of cache and worker information will be enabled.
    """

    def __init__(
        self,
        dadc: DistributedAnnDataCollection | AnnData,
        batch_keys: dict[str, dict[str, AnnDataField] | AnnDataField],
        batch_size: int = 1,
        iteration_strategy: Literal["same_order", "cache_efficient"] = "cache_efficient",
        shuffle: bool = False,
        shuffle_seed: int = 0,
        drop_last_indices: bool = False,
        drop_incomplete_batch: bool = False,
        start_idx: int | None = None,
        end_idx: int | None = None,
        worker_seed: int | None = None,
        test_mode: bool = False,
    ) -> None:
        self.dadc = dadc
        if isinstance(dadc, AnnData):
            # mimic a DistributedAnnDataCollection with a single shard
            self.dadc.limits = [dadc.n_obs]
        self.batch_keys = batch_keys
        self.batch_size = batch_size
        self.iteration_strategy = iteration_strategy
        self.shuffle = shuffle
        self.shuffle_seed = shuffle_seed
        self.drop_last_indices = drop_last_indices
        self.drop_incomplete_batch = drop_incomplete_batch
        self.start_idx = 0 if start_idx is None else start_idx
        self.end_idx = dadc.n_obs if end_idx is None else end_idx
        self.worker_seed = worker_seed
        self.epoch = 0
        self.resume_step: int | None = None
        self.test_mode = test_mode
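    # Usage sketch (hypothetical shapes; not part of the original source): a plain
    # in-memory AnnData is also accepted and is treated as a single-shard collection,
    # since ``__init__`` sets ``limits = [n_obs]``::
    #
    #     >>> import numpy as np
    #     >>> from anndata import AnnData
    #     >>> adata = AnnData(X=np.random.rand(100, 50).astype(np.float32))
    #     >>> dataset = IterableDistributedAnnDataCollectionDataset(
    #     ...     adata,
    #     ...     batch_keys={"x_ng": AnnDataField(attr="X")},
    #     ...     batch_size=10,
    #     ... )
    #     >>> len(dataset)  # ceil(100 / 10) batches on a single replica
    #     10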
    def __len__(self) -> int:
        """
        Returns the number of batches per replica.
        """
        _, num_replicas = get_rank_and_num_replicas()
        n_obs = self.end_idx - self.start_idx
        if self.drop_last_indices and n_obs % num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data.
            per_replica = n_obs // num_replicas
        else:
            per_replica = math.ceil(n_obs / num_replicas)
        if self.drop_incomplete_batch:
            batches_per_replica = per_replica // self.batch_size
        else:
            batches_per_replica = math.ceil(per_replica / float(self.batch_size))
        return batches_per_replica
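    # Worked example for ``__len__`` (hypothetical numbers): with n_obs = 11,
    # num_replicas = 2, drop_last_indices = True, and batch_size = 2:
    #     per_replica = 11 // 2 = 5            # tail dropped so ranks match
    #     batches_per_replica = ceil(5 / 2) = 3
    # With drop_last_indices = False instead:
    #     per_replica = ceil(11 / 2) = 6
    #     batches_per_replica = ceil(6 / 2) = 3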
    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for the iterator. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch.
        """
        self.epoch = epoch
    def set_resume_step(self, resume_step: int | None) -> None:
        r"""
        Sets the resume step for the iterator. When resuming from a checkpoint, this ensures
        that the iterator skips the batches that have already been processed.
        """
        self.resume_step = resume_step
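    # Resume sketch (hypothetical step count): after checkpointing at global step 7
    # with 5 batches per replica, ``load_state_dict`` restores epoch and resume_step,
    # and ``__iter__`` computes divmod(7, 5) == (1, 2), i.e. it skips the first
    # 2 batches of epoch 1 before yielding.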
    def __getitem__(self, idx: int | list[int] | slice) -> dict[str, dict[str, np.ndarray] | np.ndarray]:
        r"""
        Returns a dictionary containing the data from the :attr:`dadc` with keys specified
        by the :attr:`batch_keys` at the given index ``idx``.
        """
        adata = self.dadc[idx]
        data = tree_map(lambda field: field(adata), self.batch_keys)

        # for testing purposes
        if self.test_mode:
            rank, num_replicas = get_rank_and_num_replicas()
            worker_id, num_workers = get_worker_info()
            data["rank"] = np.array([rank])
            data["num_replicas"] = np.array([num_replicas])
            data["worker_id"] = np.array([worker_id])
            data["num_workers"] = np.array([num_workers])
            data["miss_count"] = np.array([self.dadc.cache.miss_count])
            data["epoch"] = np.array([self.epoch])

        return data
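    # Example return value (hypothetical shapes): with the ``batch_keys`` from the class
    # docstring, ``dataset[[0, 1, 2]]`` yields roughly
    #     {"x_ng": np.ndarray of shape (3, n_vars), "var_names_g": np.ndarray of shape (n_vars,)}
    # Nested dicts in ``batch_keys`` are mapped elementwise by ``tree_map``.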
    def __iter__(self):
        r"""
        Iterate through the dataset by trying to minimize the amount of anndata files
        fetched by each worker. Iterated indices are evenly divided between replicas
        (see :attr:`drop_last_indices`).

        .. note::

            1. For both strategies the amount of anndata files fetched is reduced by
               shuffling the shards first and then the datapoints within the shards.
            2. ``same_order`` strategy will iterate through the dataset in the same order
               independent of the number of replicas and workers.
            3. For ``cache_efficient`` strategy the amount of anndata files fetched is
               further reduced by assigning to each worker a contiguous chunk of the dataset.

        The returned iterator is determined by the ``torch.utils.data.get_worker_info()``
        and ``torch.distributed`` contexts.

        **Example 1**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
            num_replicas=1
            batch_size=2
            num_workers=3

        Same order:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,11) |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 2     | 0     | 1     | 2       |
        +------------+-------+-------+-------+-------+-------+---------+

        Cache efficient:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (4,5) | (8,9) | (2,3) | (6,7) | (10,11) |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 2     | 0     | 1     | 2       |
        +------------+-------+-------+-------+-------+-------+---------+

        **Example 2**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=1
            batch_size=2
            num_workers=2

        Same order:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,)   |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 0     | 1     | 0     | 1       |
        +------------+-------+-------+-------+-------+-------+---------+

        Cache efficient:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (6,7) | (2,3) | (8,9) | (4,5) | (10,)   |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 0     | 1     | 0     | 1       |
        +------------+-------+-------+-------+-------+-------+---------+

        **Example 3**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7]
            num_replicas=1
            batch_size=3
            num_workers=2

        Same order:

        +------------+---------+---------+-------+
        | batch idx  | 0       | 1       | 2     |
        +============+=========+=========+=======+
        | indices    | (0,1,2) | (3,4,5) | (6,7) |
        +------------+---------+---------+-------+
        | worker id  | 0       | 1       | 0     |
        +------------+---------+---------+-------+

        Cache efficient:

        +------------+---------+-------+---------+
        | batch idx  | 0       | 1     | 2       |
        +============+=========+=======+=========+
        | indices    | (0,1,2) | (6,7) | (3,4,5) |
        +------------+---------+-------+---------+
        | worker id  | 0       | 1     | 0       |
        +------------+---------+-------+---------+

        **Example 4**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7]
            num_replicas=1
            batch_size=3
            drop_incomplete_batch=True
            num_workers=2

        Same order:

        +------------+---------+---------+
        | batch idx  | 0       | 1       |
        +============+=========+=========+
        | indices    | (0,1,2) | (3,4,5) |
        +------------+---------+---------+
        | worker id  | 0       | 1       |
        +------------+---------+---------+

        Cache efficient:

        +------------+---------+---------+
        | batch idx  | 0       | 1       |
        +============+=========+=========+
        | indices    | (0,1,2) | (3,4,5) |
        +------------+---------+---------+
        | worker id  | 0       | 1       |
        +------------+---------+---------+

        **Example 5**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=2
            drop_last_indices=True
            batch_size=2
            num_workers=1

        Same order:

        *Replica 1*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (0,2) | (4,6) | (8,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        *Replica 2*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (1,3) | (5,7) | (9,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        Cache efficient:

        *Replica 1*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (0,1) | (2,3) | (4,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        *Replica 2*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (5,6) | (7,8) | (9,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        **Example 6**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=2
            drop_last_indices=False
            batch_size=2
            num_workers=1

        Same order:

        *Replica 1*

        +------------+-------+-------+--------+
        | batch idx  | 0     | 1     | 2      |
        +============+=======+=======+========+
        | indices    | (0,2) | (4,6) | (8,10) |
        +------------+-------+-------+--------+
        | worker id  | 0     | 0     | 0      |
        +------------+-------+-------+--------+

        *Replica 2*

        +------------+-------+-------+-------+
        | batch idx  | 0     | 1     | 2     |
        +============+=======+=======+=======+
        | indices    | (1,3) | (5,7) | (9,0) |
        +------------+-------+-------+-------+
        | worker id  | 0     | 0     | 0     |
        +------------+-------+-------+-------+

        Cache efficient:

        *Replica 1*

        +------------+-------+-------+-------+
        | batch idx  | 0     | 1     | 2     |
        +============+=======+=======+=======+
        | indices    | (0,1) | (2,3) | (4,5) |
        +------------+-------+-------+-------+
        | worker id  | 0     | 0     | 0     |
        +------------+-------+-------+-------+

        *Replica 2*

        +------------+-------+-------+--------+
        | batch idx  | 0     | 1     | 2      |
        +============+=======+=======+========+
        | indices    | (6,7) | (8,9) | (10,0) |
        +------------+-------+-------+--------+
        | worker id  | 0     | 0     | 0      |
        +------------+-------+-------+--------+

        **Resuming from a checkpoint:**

        1. For persistent workers the state (:attr:`epoch` and :attr:`resume_step`) is
           initially set by the :meth:`load_state_dict` method. At the end of the iteration,
           the :attr:`epoch` is incremented and the :attr:`resume_step` is set to ``None``.
        2. For non-persistent workers the state is initially set by the
           :meth:`load_state_dict` method. The :attr:`epoch` is updated by the
           ``on_train_epoch_start`` hook and the :attr:`resume_step` is set to ``None``
           by the ``on_train_epoch_end`` hook.
        3. If the :attr:`resume_step` is not ``None``, then the worker will skip the batches
           that have already been processed. The workers are shifted based on the global step.
        """
        if self.test_mode and isinstance(self.dadc, DistributedAnnDataCollection):
            # clear lru cache
            self.dadc.cache.clear()

        # replicas
        rank, num_replicas = get_rank_and_num_replicas()

        n_obs = self.end_idx - self.start_idx
        if self.drop_last_indices and n_obs % num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data.
            per_replica = n_obs // num_replicas
        else:
            per_replica = math.ceil(n_obs / num_replicas)
        total_size = per_replica * num_replicas
        batches_per_replica = len(self)

        # workers
        worker_id, num_workers = get_worker_info()

        if self.resume_step is not None:
            num_epochs_that_stepped, num_batches_that_stepped = divmod(self.resume_step, batches_per_replica)
            # self.epoch can be inconsistent with the global step if checkpointed mid-epoch and not adjusted
            if self.epoch < num_epochs_that_stepped:
                raise ValueError(
                    f"Epoch {self.epoch} is less than the number of epochs "
                    f"that have been processed {num_epochs_that_stepped}."
                )
            # shift worker_id based on the global step
            worker_id = (worker_id - num_batches_that_stepped) % num_workers
        else:
            num_batches_that_stepped = 0

        # seed workers
        if self.worker_seed is not None:
            global_worker_id = self.epoch * (num_replicas * num_workers) + rank * num_workers + worker_id
            current_worker_seed = self.worker_seed + global_worker_id
            random.seed(current_worker_seed)
            np.random.seed(current_worker_seed)
            torch.manual_seed(current_worker_seed)

        # indices
        if self.shuffle:
            rng = torch.Generator()
            rng.manual_seed(self.shuffle_seed + self.epoch)
            limits = [idx for idx in self.dadc.limits if idx > self.start_idx and idx < self.end_idx]
            iter_limits = list(zip([self.start_idx] + limits, limits + [self.end_idx]))
            # shuffle shards
            limit_indices = torch.randperm(len(iter_limits), generator=rng).tolist()
            indices = []
            for limit_idx in limit_indices:
                lower, upper = iter_limits[limit_idx]
                # shuffle cells within shards
                indices.extend((torch.randperm(upper - lower, generator=rng) + lower).tolist())
        else:
            indices = list(range(self.start_idx, self.end_idx))

        if not self.drop_last_indices:
            # add extra samples to make it evenly divisible
            padding_size = total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible
            indices = indices[:total_size]

        if self.iteration_strategy == "same_order":
            # replica indices
            indices = indices[rank:total_size:num_replicas]
            if len(indices) != per_replica:
                raise ValueError(
                    f"The number of indices must be equal to the per_replica size. "
                    f"Got {len(indices)} != {per_replica} at rank {rank}."
                )
            # in python 3.12 `chunked_iter` can be replaced with `itertools.batched`
            for worker_batch_idx, batch_indices in enumerate(
                islice(chunked_iter(indices, self.batch_size), worker_id, None, num_workers)
            ):
                if self.drop_incomplete_batch and len(batch_indices) < self.batch_size:
                    continue
                current_batch_idx = worker_batch_idx * num_workers + worker_id
                if current_batch_idx < num_batches_that_stepped:
                    continue
                yield self[batch_indices]
        elif self.iteration_strategy == "cache_efficient":
            # replica indices
            indices = indices[rank * per_replica : (rank + 1) * per_replica]
            if len(indices) != per_replica:
                raise ValueError(
                    f"The number of indices must be equal to the per_replica size. "
                    f"Got {len(indices)} != {per_replica} at rank {rank}."
                )
            # worker indices
            batches_per_worker = math.ceil(batches_per_replica / float(num_workers))
            per_worker = batches_per_worker * self.batch_size
            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, per_replica)
            indices = indices[iter_start:iter_end]
            # in python 3.12 `chunked_iter` can be replaced with `itertools.batched`
            for worker_batch_idx, batch_indices in enumerate(chunked_iter(indices, self.batch_size)):
                if self.drop_incomplete_batch and len(batch_indices) < self.batch_size:
                    continue
                current_batch_idx = worker_batch_idx * num_workers + worker_id
                if current_batch_idx < num_batches_that_stepped:
                    continue
                yield self[batch_indices]

        # sets epoch and resume_step for persistent workers
        self.set_epoch(self.epoch + 1)
        self.set_resume_step(None)
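    # Worked chunking for the ``cache_efficient`` branch (numbers from Example 2 above):
    # with per_replica = 11, batch_size = 2, num_workers = 2:
    #     batches_per_replica = ceil(11 / 2) = 6
    #     batches_per_worker = ceil(6 / 2) = 3, so per_worker = 3 * 2 = 6
    #     worker 0 reads indices[0:6], worker 1 reads indices[6:11]
    # which is why each worker touches a contiguous chunk (and thus fewer files).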
    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        r"""
        Loads the state of the dataset from the given state dictionary.

        Args:
            state_dict: State dictionary containing the state of the dataset.
        """
        # trainer.fit_loop.epoch_progress.current.completed
        self.epoch = state_dict["epoch"]
        # trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer_steps
        self.resume_step = state_dict["resume_step"]
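
# Usage sketch (hypothetical data and parameters; not part of the original source):
# streams batches through a ``torch.utils.data.DataLoader``. ``batch_size=None`` is
# required because batching happens inside the dataset itself.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    adata = AnnData(X=np.random.rand(100, 50).astype(np.float32))
    dataset = IterableDistributedAnnDataCollectionDataset(
        adata,
        batch_keys={"x_ng": AnnDataField(attr="X")},
        batch_size=10,
        shuffle=True,
        shuffle_seed=0,
    )
    loader = DataLoader(dataset, batch_size=None, num_workers=0)
    for epoch in range(2):
        # with non-persistent (or zero) workers, reshuffling is driven by set_epoch;
        # persistent workers increment their own epoch at the end of each iteration
        dataset.set_epoch(epoch)
        for batch in loader:
            assert batch["x_ng"].shape == (10, 50)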