# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import math
from itertools import islice
from typing import Literal

import numpy as np
import torch
from anndata import AnnData
from boltons.iterutils import chunked_iter
from import IterableDataset

from import DistributedAnnDataCollection
from import AnnDataField
from import get_rank_and_num_replicas, get_worker_info

class IterableDistributedAnnDataCollectionDataset(IterableDataset):
    r"""
    Iterable DistributedAnnDataCollection Dataset.

    When :attr:`shuffle` is set to ``True`` then the iterator yields datapoints that are
    uniformly sampled from the entire dataset. Typical use cases include training variational
    models using the stochastic gradient descent algorithm.

    In order to maximize buffer usage, we only shuffle shards and datapoints within individual
    shards (and not across shards). Therefore, to achieve unbiased pseudo-random uniform sampling,
    it is imperative that the shards themselves contain datapoints that are uniformly sampled
    from the entire dataset. If correlations exist between datapoints in a given shard (e.g. all
    cells coming from the same tissue or experiment), then this assumption is violated. It is
    the user's responsibility to prepare appropriately shuffled data shards.

    Example::

        >>> from cellarium.ml.data import (
        ...     DistributedAnnDataCollection,
        ...     IterableDistributedAnnDataCollectionDataset,
        ... )
        >>> from cellarium.ml.utilities.data import AnnDataField, densify

        >>> dadc = DistributedAnnDataCollection(
        ...     "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...     shard_size=10_000,
        ...     max_cache_size=2)

        >>> dataset = IterableDistributedAnnDataCollectionDataset(
        ...     dadc,
        ...     batch_keys={
        ...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
        ...         "var_names_g": AnnDataField(attr="var_names"),
        ...     },
        ...     batch_size=5000,
        ...     iteration_strategy="cache_efficient",
        ...     shuffle=True,
        ...     seed=0,
        ...     drop_last_indices=True,
        ... )

    Args:
        dadc:
            DistributedAnnDataCollection or AnnData from which to load the data.
        batch_keys:
            Dictionary that specifies which attributes and keys of the :attr:`dadc` to return
            in the batch data and how to convert them. Keys must correspond to
            the input keys of the transforms or the model. Values must be instances of
            :class:`cellarium.ml.utilities.data.AnnDataField`.
        batch_size:
            How many samples per batch to load.
        iteration_strategy:
            Strategy to use for iterating through the dataset. Options are ``same_order`` and
            ``cache_efficient``. ``same_order`` will iterate through the dataset in the same order
            independent of the number of replicas and workers. ``cache_efficient`` will try to
            minimize the amount of anndata files fetched by each worker.
        shuffle:
            If ``True``, the data is reshuffled at every epoch.
        seed:
            Random seed used to shuffle the sampler if :attr:`shuffle=True`.
        drop_last_indices:
            If ``True``, then the sampler will drop the tail of the data
            to make it evenly divisible across the number of replicas. If ``False``,
            the sampler will add extra indices to make the data evenly divisible across
            the replicas.
        drop_incomplete_batch:
            If ``True``, the dataloader will drop the incomplete batch if the dataset size is not
            divisible by the batch size.
        start_idx:
            The starting index of the dataset. If ``None``, then the dataset will start from
            the first index.
        end_idx:
            The ending index (exclusive) of the dataset. If ``None``, then the dataset will end at
            the last index (inclusive).
        test_mode:
            If ``True``, then tracking of cache and worker information will be enabled.
    """

    # The only iteration strategies ``__iter__`` knows how to serve.
    _ITERATION_STRATEGIES = ("same_order", "cache_efficient")

    def __init__(
        self,
        dadc: DistributedAnnDataCollection | AnnData,
        batch_keys: dict[str, AnnDataField],
        batch_size: int = 1,
        iteration_strategy: Literal["same_order", "cache_efficient"] = "cache_efficient",
        shuffle: bool = False,
        seed: int = 0,
        drop_last_indices: bool = False,
        drop_incomplete_batch: bool = False,
        start_idx: int | None = None,
        end_idx: int | None = None,
        test_mode: bool = False,
    ) -> None:
        self.dadc = dadc
        if isinstance(dadc, AnnData):
            # mimic a DistributedAnnDataCollection: a single "shard" spanning all observations
            self.dadc.limits = [dadc.n_obs]
        self.batch_keys = batch_keys
        self.batch_size = batch_size
        self.iteration_strategy = iteration_strategy
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last_indices = drop_last_indices
        self.drop_incomplete_batch = drop_incomplete_batch
        self.start_idx = 0 if start_idx is None else start_idx
        self.end_idx = dadc.n_obs if end_idx is None else end_idx
        self.epoch = 0
        self.test_mode = test_mode

    @staticmethod
    def _per_replica_size(n_obs: int, num_replicas: int, drop_last_indices: bool) -> int:
        """
        Return the number of indices assigned to each replica.

        With ``drop_last_indices`` the tail is dropped (floor division); otherwise the data is
        padded up to the next multiple of ``num_replicas`` (ceiling division). Both agree when
        ``n_obs`` is already divisible by ``num_replicas``.
        """
        if drop_last_indices:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data.
            return n_obs // num_replicas
        return math.ceil(n_obs / num_replicas)

    @staticmethod
    def _fit_to_total_size(indices: list[int], total_size: int, drop_last_indices: bool) -> list[int]:
        """
        Return ``indices`` truncated or padded so that it has exactly ``total_size`` entries.

        Padding re-uses indices from the start of the list, cycling through the list as many
        times as needed when more padding than available indices is required.
        """
        if drop_last_indices:
            # remove tail of data to make it evenly divisible.
            return indices[:total_size]

        # add extra samples to make it evenly divisible
        padding_size = total_size - len(indices)
        if padding_size <= len(indices):
            return indices + indices[:padding_size]
        repeats = math.ceil(padding_size / len(indices))
        return indices + (indices * repeats)[:padding_size]

    def _ordered_indices(self) -> list[int]:
        """
        Return the indices in ``[start_idx, end_idx)`` in the order used for the current epoch.

        Without shuffling this is simply the ascending range. With shuffling, shards (as delimited
        by ``dadc.limits``) are permuted first and then the datapoints inside each shard, using a
        generator seeded with ``seed + epoch``.
        """
        if not self.shuffle:
            return list(range(self.start_idx, self.end_idx))

        rng = torch.Generator()
        rng.manual_seed(self.seed + self.epoch)
        # shard boundaries that fall strictly inside the requested index window
        limits = [idx for idx in self.dadc.limits if self.start_idx < idx < self.end_idx]
        iter_limits = list(zip([self.start_idx] + limits, limits + [self.end_idx]))
        # shuffle shards
        limit_indices = torch.randperm(len(iter_limits), generator=rng).tolist()
        indices: list[int] = []
        for limit_idx in limit_indices:
            lower, upper = iter_limits[limit_idx]
            # shuffle cells within shards
            indices.extend((torch.randperm(upper - lower, generator=rng) + lower).tolist())
        return indices

    def __len__(self) -> int:
        """
        Returns the number of batches per replica.
        """
        _, num_replicas = get_rank_and_num_replicas()
        per_replica = self._per_replica_size(
            self.end_idx - self.start_idx, num_replicas, self.drop_last_indices
        )

        if self.drop_incomplete_batch:
            return per_replica // self.batch_size
        return math.ceil(per_replica / float(self.batch_size))

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for the iterator. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch.
        """
        self.epoch = epoch

    def __getitem__(self, idx: int | list[int] | slice) -> dict[str, np.ndarray]:
        r"""
        Returns a dictionary containing the data from the :attr:`dadc` with keys specified by the :attr:`batch_keys`
        at the given index ``idx``.

        In :attr:`test_mode` the dictionary additionally contains the keys ``rank``, ``num_replicas``,
        ``worker_id``, ``num_workers``, ``miss_count`` and ``epoch``, each as a one-element array.
        """
        data = {}

        adata = self.dadc[idx]

        for key, field in self.batch_keys.items():
            data[key] = field(adata)

        # for testing purposes
        if self.test_mode:
            rank, num_replicas = get_rank_and_num_replicas()
            worker_id, num_workers = get_worker_info()
            # Only a DistributedAnnDataCollection is expected to expose a shard cache
            # (``__iter__`` makes the same distinction); report zero misses otherwise.
            if isinstance(self.dadc, DistributedAnnDataCollection):
                miss_count = self.dadc.cache.miss_count
            else:
                miss_count = 0
            data["rank"] = np.array([rank])
            data["num_replicas"] = np.array([num_replicas])
            data["worker_id"] = np.array([worker_id])
            data["num_workers"] = np.array([num_workers])
            data["miss_count"] = np.array([miss_count])
            data["epoch"] = np.array([self.epoch])

        return data

    def __iter__(self):
        r"""
        Iterate through the dataset by trying to minimize the amount of anndata files fetched by each worker.
        Iterated indices are evenly divided between replicas (see :attr:`drop_last_indices`).

        .. note::

            1. For both strategies the amount of anndata files fetched is reduced by
               shuffling the shards first and then the datapoints within the shards.
            2. ``same_order`` strategy will iterate through the dataset in the same order independent
               of the number of replicas and workers.
            3. For ``cache_efficient`` strategy the amount of anndata files fetched is further
               reduced by assigning to each worker a contiguous chunk of the dataset.
               The returned iterator is determined by the ``torch.utils.data.get_worker_info()``
               and ``torch.distributed`` contexts.

        Raises:
            ValueError: If :attr:`iteration_strategy` is not ``same_order`` or ``cache_efficient``,
                or if the number of indices selected for this replica differs from the expected
                per-replica size.

        **Example 1**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
            num_replicas=1
            batch_size=2
            num_workers=3

        Same order:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,11) |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 2     | 0     | 1     | 2       |
        +------------+-------+-------+-------+-------+-------+---------+

        Cache efficient:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (4,5) | (8,9) | (2,3) | (6,7) | (10,11) |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 2     | 0     | 1     | 2       |
        +------------+-------+-------+-------+-------+-------+---------+

        **Example 2**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=1
            batch_size=2
            num_workers=2

        Same order:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,)   |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 0     | 1     | 0     | 1       |
        +------------+-------+-------+-------+-------+-------+---------+

        Cache efficient:

        +------------+-------+-------+-------+-------+-------+---------+
        | batch idx  | 0     | 1     | 2     | 3     | 4     | 5       |
        +============+=======+=======+=======+=======+=======+=========+
        | indices    | (0,1) | (6,7) | (2,3) | (8,9) | (4,5) | (10,)   |
        +------------+-------+-------+-------+-------+-------+---------+
        | worker id  | 0     | 1     | 0     | 1     | 0     | 1       |
        +------------+-------+-------+-------+-------+-------+---------+

        **Example 3**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7]
            num_replicas=1
            batch_size=3
            num_workers=2

        Same order:

        +------------+---------+---------+-------+
        | batch idx  | 0       | 1       | 2     |
        +============+=========+=========+=======+
        | indices    | (0,1,2) | (3,4,5) | (6,7) |
        +------------+---------+---------+-------+
        | worker id  | 0       | 1       | 0     |
        +------------+---------+---------+-------+

        Cache efficient:

        +------------+---------+-------+---------+
        | batch idx  | 0       | 1     | 2       |
        +============+=========+=======+=========+
        | indices    | (0,1,2) | (6,7) | (3,4,5) |
        +------------+---------+-------+---------+
        | worker id  | 0       | 1     | 0       |
        +------------+---------+-------+---------+

        **Example 4**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7]
            num_replicas=1
            batch_size=3
            drop_incomplete_batch=True
            num_workers=2

        Same order:

        +------------+---------+---------+
        | batch idx  | 0       | 1       |
        +============+=========+=========+
        | indices    | (0,1,2) | (3,4,5) |
        +------------+---------+---------+
        | worker id  | 0       | 1       |
        +------------+---------+---------+

        Cache efficient:

        +------------+---------+---------+
        | batch idx  | 0       | 2       |
        +============+=========+=========+
        | indices    | (0,1,2) | (3,4,5) |
        +------------+---------+---------+
        | worker id  | 0       | 0       |
        +------------+---------+---------+

        **Example 5**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=2
            drop_last_indices=True
            batch_size=2
            num_workers=1

        Same order:

        *Replica 1*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (0,2) | (4,6) | (8,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        *Replica 2*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (1,3) | (5,7) | (9,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        Cache efficient:

        *Replica 1*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (0,1) | (2,3) | (4,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        *Replica 2*

        +------------+-------+-------+------+
        | batch idx  | 0     | 1     | 2    |
        +============+=======+=======+======+
        | indices    | (5,6) | (7,8) | (9,) |
        +------------+-------+-------+------+
        | worker id  | 0     | 0     | 0    |
        +------------+-------+-------+------+

        **Example 6**::

            indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
            num_replicas=2
            drop_last_indices=False
            batch_size=2
            num_workers=1

        Same order:

        *Replica 1*

        +------------+-------+-------+--------+
        | batch idx  | 0     | 1     | 2      |
        +============+=======+=======+========+
        | indices    | (0,2) | (4,6) | (8,10) |
        +------------+-------+-------+--------+
        | worker id  | 0     | 0     | 0      |
        +------------+-------+-------+--------+

        *Replica 2*

        +------------+-------+-------+-------+
        | batch idx  | 0     | 1     | 2     |
        +============+=======+=======+=======+
        | indices    | (1,3) | (5,7) | (9,0) |
        +------------+-------+-------+-------+
        | worker id  | 0     | 0     | 0     |
        +------------+-------+-------+-------+

        Cache efficient:

        *Replica 1*

        +------------+-------+-------+-------+
        | batch idx  | 0     | 1     | 2     |
        +============+=======+=======+=======+
        | indices    | (0,1) | (2,3) | (4,5) |
        +------------+-------+-------+-------+
        | worker id  | 0     | 0     | 0     |
        +------------+-------+-------+-------+

        *Replica 2*

        +------------+-------+-------+--------+
        | batch idx  | 0     | 1     | 2      |
        +============+=======+=======+========+
        | indices    | (6,7) | (8,9) | (10,0) |
        +------------+-------+-------+--------+
        | worker id  | 0     | 0     | 0      |
        +------------+-------+-------+--------+
        """
        # Fail loudly instead of silently yielding nothing for an unknown strategy.
        if self.iteration_strategy not in self._ITERATION_STRATEGIES:
            raise ValueError(
                f"`iteration_strategy` must be one of {self._ITERATION_STRATEGIES}. "
                f"Got {self.iteration_strategy!r}."
            )
        same_order = self.iteration_strategy == "same_order"

        if self.test_mode and isinstance(self.dadc, DistributedAnnDataCollection):
            # clear lru cache
            self.dadc.cache.clear()

        # replicas
        rank, num_replicas = get_rank_and_num_replicas()
        n_obs = self.end_idx - self.start_idx
        per_replica = self._per_replica_size(n_obs, num_replicas, self.drop_last_indices)
        total_size = per_replica * num_replicas

        # workers
        worker_id, num_workers = get_worker_info()

        # indices for the whole dataset, made evenly divisible across replicas
        indices = self._fit_to_total_size(self._ordered_indices(), total_size, self.drop_last_indices)

        # replica indices: strided for ``same_order``, one contiguous slice for ``cache_efficient``
        if same_order:
            indices = indices[rank:total_size:num_replicas]
        else:
            indices = indices[rank * per_replica : (rank + 1) * per_replica]
        if len(indices) != per_replica:
            raise ValueError(
                f"The number of indices must be equal to the per_replica size. "
                f"Got {len(indices)} != {per_replica} at rank {rank}."
            )

        # in python 3.12 `chunked_iter` can be replaced with `itertools.batched`
        if same_order:
            # workers take batches round-robin
            batches = islice(chunked_iter(indices, self.batch_size), worker_id, None, num_workers)
        else:
            # worker indices: each worker gets a contiguous run of whole batches
            batches_per_replica = math.ceil(per_replica / float(self.batch_size))
            batches_per_worker = math.ceil(batches_per_replica / float(num_workers))
            per_worker = batches_per_worker * self.batch_size

            iter_start = worker_id * per_worker
            iter_end = min(iter_start + per_worker, per_replica)
            batches = chunked_iter(indices[iter_start:iter_end], self.batch_size)

        for batch_indices in batches:
            if self.drop_incomplete_batch and len(batch_indices) < self.batch_size:
                continue
            yield self[batch_indices]

        # Sets epoch for persistent workers
        self.set_epoch(self.epoch + 1)