Data

class cellarium.ml.data.AnnDataSchema(adata: AnnData, obs_columns_to_validate: Sequence[str] | None = None)

Bases: object

Store reference AnnData attributes for a collection of distributed AnnData objects.

Validate AnnData objects against reference attributes.

Example:

>>> import numpy as np
>>> from anndata import AnnData
>>> from cellarium.ml.data import AnnDataSchema
>>> ref_adata = AnnData(X=np.zeros((2, 3)))
>>> ref_adata.obs["batch"] = [0, 1]
>>> ref_adata.var["mu"] = ["a", "b", "c"]
>>> adata = AnnData(X=np.zeros((4, 3)))
>>> adata.obs["batch"] = [2, 3, 4, 5]
>>> adata.var["mu"] = ["a", "b", "c"]
>>> schema = AnnDataSchema(ref_adata)
>>> schema.validate_anndata(adata)

Parameters:
  • adata (AnnData) – Reference AnnData object.

  • obs_columns_to_validate (Sequence[str] | None) – Subset of columns to validate in the .obs attribute. If None, all columns are validated.
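One plausible way to check an obs_columns_to_validate subset is sketched below. validate_obs_columns is a hypothetical helper, not part of cellarium.ml. It assumes that validation compares column presence and dtypes rather than values, which is consistent with the example above passing even though the batch values differ between ref_adata and adata.

```python
from __future__ import annotations

from collections.abc import Sequence

import pandas as pd


def validate_obs_columns(
    ref_obs: pd.DataFrame,
    obs: pd.DataFrame,
    columns: Sequence[str] | None = None,
) -> None:
    """Hypothetical check: selected obs columns must exist and share the reference dtype.

    Values are deliberately not compared, so shards may hold different cells.
    """
    # None means "validate every column of the reference obs".
    to_check = list(ref_obs.columns) if columns is None else list(columns)
    for col in to_check:
        if col not in obs.columns:
            raise ValueError(f"obs column {col!r} is missing")
        if obs[col].dtype != ref_obs[col].dtype:
            raise ValueError(
                f"obs column {col!r} has dtype {obs[col].dtype}, expected {ref_obs[col].dtype}"
            )
```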

validate_anndata(adata: AnnData) → None

Validate that adata conforms to the reference attributes stored in the schema.

Parameters:

adata (AnnData) – AnnData object to validate.

Return type:

None

class cellarium.ml.data.DistributedAnnDataCollection(filenames: Sequence[str] | str, limits: Iterable[int] | None = None, shard_size: int | None = None, last_shard_size: int | None = None, max_cache_size: int = 1, cache_size_strictly_enforced: bool = True, label: str | None = None, keys: Sequence[str] | None = None, index_unique: str | None = None, convert: Callable | dict[str, Callable | dict[str, Callable]] | None = None, indices_strict: bool = True, obs_columns_to_validate: Sequence[str] | None = None)

Bases: AnnCollection

Distributed AnnData Collection.

This class is a wrapper around AnnCollection where adatas is a list of LazyAnnData objects.

Underlying anndata files must conform to the same schema (see AnnDataSchema.validate_anndata). The schema is inferred from the first AnnData file in the collection. Otherwise, individual AnnData files may vary in the number of cells and in the actual content stored in X, layers, obs and obsm.

Example 1:

>>> from cellarium.ml.data import DistributedAnnDataCollection
>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10000,  # use if shards are sized evenly
...     max_cache_size=2)

Example 2:

>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10000,
...     last_shard_size=6000,  # use if the size of the last shard is different
...     max_cache_size=2)

Example 3:

>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     limits=[500, 1000, 2000, 2500, 3000, 4000],  # use if shards are sized unevenly
...     max_cache_size=2)
Parameters:
  • filenames (Sequence[str] | str) – Names of anndata files.

  • limits (Iterable[int] | None) – Global cell index limits of the shards: the i-th entry is the (exclusive) end index of shard i, that is, the cumulative number of cells in shards 0 through i. If None, the limits are inferred from shard_size and last_shard_size.

  • shard_size (int | None) – The number of cells in each anndata file (shard). Must be specified if limits is not provided.

  • last_shard_size (int | None) – Size of the last shard. If not None, the last shard has this size, which may differ from shard_size.

  • max_cache_size (int) – Maximum number of anndata files held in the cache.

  • cache_size_strictly_enforced (bool) – Assert that the number of retrieved anndatas does not exceed max_cache_size.

  • label (str | None) – Column in obs to place batch information in. If it’s None, no column is added.

  • keys (Sequence[str] | None) – Names for each AnnData object in the collection. These values are used as the column values for label, or appended to the index if index_unique is not None. If None, keys are set to filenames.

  • index_unique (str | None) – Whether to make the index unique by using the keys. If provided, this is the delimiter in {orig_idx}{index_unique}{key}. When None, the original indices are kept.

  • convert (Callable | dict[str, Callable | dict[str, Callable]] | None) – You can pass a function or a Mapping of functions which will be applied to the values of attributes (obs, obsm, layers, X) or to specific keys of these attributes in the subset object. Specify an attribute and a key (if needed) as keys of the passed Mapping and a function to be applied as a value.

  • indices_strict (bool) – If True, arrays from the subset objects will always have the same order of indices as in selection used to subset. This parameter can be set to False if the order in the returned arrays is not important, for example, when using them for stochastic gradient descent. In this case the performance of subsetting can be a bit better.

  • obs_columns_to_validate (Sequence[str] | None) – Subset of columns to validate in the obs attribute. If None, all columns are validated.
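The sketch below shows how a brace pattern such as adata{000..005}.h5ad can expand to six filenames, and how limits could be derived from shard_size and last_shard_size. Both helpers are hypothetical, only numeric ranges are handled, and the limits are assumed to be cumulative exclusive end indices, consistent with Example 3 and with the (inclusive, exclusive) limits of LazyAnnData.

```python
from __future__ import annotations

import re

# Matches a single numeric brace range such as "{000..005}".
_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_numeric_range(pattern: str) -> list[str]:
    """Expand the first "{start..end}" range, keeping the zero padding of start."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start, end = match.group(1), match.group(2)
    width = len(start)  # "000" means three-digit, zero-padded names
    head, tail = pattern[: match.start()], pattern[match.end():]
    return [f"{head}{i:0{width}d}{tail}" for i in range(int(start), int(end) + 1)]


def infer_limits(num_shards: int, shard_size: int, last_shard_size: int | None = None) -> list[int]:
    """Cumulative (exclusive) end index of each shard."""
    limits = [shard_size * (i + 1) for i in range(num_shards)]
    if last_shard_size is not None:
        # Only the final shard may differ in size.
        limits[-1] = shard_size * (num_shards - 1) + last_shard_size
    return limits


def shard_ranges(limits: list[int]) -> list[tuple[int, int]]:
    """(inclusive start, exclusive end) cell range of each shard."""
    starts = [0] + list(limits[:-1])
    return list(zip(starts, limits))
```

With Example 2's settings, infer_limits(6, 10000, 6000) ends at 56000, so the collection would hold 56,000 cells.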

__getitem__(index: slice | int | str | int64 | ndarray | ellipsis | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray | ellipsis] | tuple[slice | int | str | int64 | ndarray | ellipsis, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray, ellipsis] | tuple[ellipsis, slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, ellipsis, slice | int | str | int64 | ndarray] | spmatrix | sparray) → AnnData

Materialize and gather anndata files at given indices from the list of lazy anndatas.

LazyAnnData instances corresponding to cells in the index are materialized.

Parameters:

index (slice | int | str | int64 | ndarray | ellipsis | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray | ellipsis] | tuple[slice | int | str | int64 | ndarray | ellipsis, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray, ellipsis] | tuple[ellipsis, slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, ellipsis, slice | int | str | int64 | ndarray] | spmatrix | sparray)

Return type:

AnnData
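To decide which LazyAnnData instances must be materialized, global cell indices have to be mapped to shards. A minimal sketch with NumPy's searchsorted follows. split_index_by_shard is hypothetical, and it assumes limits are cumulative exclusive end indices and that every index is in range.

```python
from __future__ import annotations

import numpy as np


def split_index_by_shard(global_idx, limits) -> dict[int, np.ndarray]:
    """Group global cell indices by shard and convert them to shard-local indices."""
    limits = np.asarray(limits)
    global_idx = np.asarray(global_idx)
    starts = np.concatenate(([0], limits[:-1]))
    # side="right": an index equal to a limit belongs to the next shard.
    shard_ids = np.searchsorted(limits, global_idx, side="right")
    return {
        int(shard): global_idx[shard_ids == shard] - starts[shard]
        for shard in np.unique(shard_ids)
    }
```

With the limits from Example 3 of DistributedAnnDataCollection, the indices [0, 499, 500, 2600] touch only shards 0, 1 and 4, so only those three files would need to be loaded.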

materialize(adatas_oidx: list[ndarray | None], vidx: slice | int | str | int64 | ndarray) → list[AnnData]

Buffer and return anndata files at given indices from the list of lazy anndatas.

For efficiency, files that are already cached are retrieved first, and only then are new files loaded into the cache.

Parameters:
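  • adatas_oidx (list[ndarray | None]) – Observation indices for each lazy anndata; None for anndatas that are not part of the selection.

  • vidx (slice | int | str | int64 | ndarray) – Variable (feature) index.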

Return type:

list[AnnData]
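Retrieving cached files first matters for an LRU cache: if a new file were loaded first, it could evict a file that the same request still needs, forcing a second download. The sketch below demonstrates this with a hypothetical minimal LRUCache built on collections.OrderedDict. It is not the cache class cellarium.ml uses, and fake_loader stands in for reading an .h5ad file.

```python
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Callable


class LRUCache:
    """Hypothetical minimal LRU cache of loaded files, keyed by filename."""

    def __init__(self, maxsize: int) -> None:
        self.maxsize = maxsize
        self._data: OrderedDict[str, Any] = OrderedDict()

    def __contains__(self, key: str) -> bool:
        return key in self._data

    def get(self, key: str, loader: Callable[[str], Any]) -> Any:
        if key in self._data:
            self._data.move_to_end(key)  # mark as most recently used
            return self._data[key]
        value = loader(key)
        self._data[key] = value
        if len(self._data) > self.maxsize:
            self._data.popitem(last=False)  # evict the least recently used entry
        return value


def materialize(cache: LRUCache, filenames: list[str], loader: Callable[[str], Any]) -> list[Any]:
    # Pass 1: touch files that are already cached so they become most recently used.
    hits = {f: cache.get(f, loader) for f in filenames if f in cache}
    # Pass 2: load missing files; eviction now targets files outside this request.
    return [hits[f] if f in hits else cache.get(f, loader) for f in filenames]


if __name__ == "__main__":
    loaded: list[str] = []

    def fake_loader(name: str) -> str:
        loaded.append(name)
        return f"<{name}>"

    cache = LRUCache(maxsize=2)
    cache.get("a.h5ad", fake_loader)
    cache.get("b.h5ad", fake_loader)  # "a.h5ad" is now least recently used
    loaded.clear()
    materialize(cache, ["c.h5ad", "a.h5ad"], fake_loader)
    print(loaded)  # ['c.h5ad']: "b.h5ad" is evicted instead of "a.h5ad"
```

Requesting the same two files one by one in the given order would instead evict "a.h5ad" while loading "c.h5ad" and then reload it.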

class cellarium.ml.data.DistributedAnnDataCollectionView(reference, convert, resolved_idx)

Bases: AnnCollectionView

Distributed AnnData Collection View.

This class is a wrapper around AnnCollectionView where adatas is a list of LazyAnnData objects.

property obs_names: Index

Gather and return the obs_names from all AnnData objects in the collection.

class cellarium.ml.data.IterableDistributedAnnDataCollectionDataset(dadc: DistributedAnnDataCollection | AnnData, batch_keys: dict[str, dict[str, AnnDataField] | AnnDataField], batch_size: int = 1, iteration_strategy: Literal['same_order', 'cache_efficient'] = 'cache_efficient', shuffle: bool = False, shuffle_seed: int = 0, drop_last_indices: bool = False, drop_incomplete_batch: bool = False, start_idx: int | None = None, end_idx: int | None = None, worker_seed: int | None = None, test_mode: bool = False)

Bases: IterableDataset

Iterable DistributedAnnDataCollection Dataset.

When shuffle is set to True, the iterator yields datapoints that are uniformly sampled from the entire dataset. Typical use cases include training variational models with stochastic gradient descent.

In order to maximize buffer usage, we only shuffle shards and datapoints within individual shards (and not across shards). Therefore, to achieve unbiased pseudo-random uniform sampling, it is imperative that the shards themselves contain datapoints that are uniformly sampled from the entire dataset. If correlations exist between datapoints in a given shard (e.g. all cells coming from the same tissue or experiment), then this assumption is violated. It is the user’s responsibility to prepare appropriately shuffled data shards.
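The two-level shuffle described above can be sketched as follows. NumPy's default_rng stands in for whatever random generator the library actually uses, and two_level_shuffle is a hypothetical helper. The key property is that cells from one shard stay contiguous in the output, which is why the shards themselves must already be well mixed.

```python
from __future__ import annotations

import numpy as np


def two_level_shuffle(limits: list[int], seed: int) -> np.ndarray:
    """Shuffle shard order, then cells within each shard; never mix cells across shards.

    limits are assumed to be cumulative (exclusive) shard end indices.
    """
    rng = np.random.default_rng(seed)
    starts = [0] + list(limits[:-1])
    chunks = []
    for shard in rng.permutation(len(limits)):  # level 1: random shard order
        start, end = starts[shard], limits[shard]
        chunks.append(start + rng.permutation(end - start))  # level 2: within the shard
    return np.concatenate(chunks)
```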

Example:

>>> from cellarium.ml.data import (
...     DistributedAnnDataCollection,
...     IterableDistributedAnnDataCollectionDataset,
... )
>>> from cellarium.ml.utilities.data import AnnDataField, densify

>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10_000,
...     max_cache_size=2)

>>> dataset = IterableDistributedAnnDataCollectionDataset(
...     dadc,
...     batch_keys={
...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
...         "var_names_g": AnnDataField(attr="var_names"),
...     },
...     batch_size=5000,
...     iteration_strategy="cache_efficient",
...     shuffle=True,
...     shuffle_seed=0,
...     drop_last_indices=True,
... )
Parameters:
  • dadc (DistributedAnnDataCollection | AnnData) – DistributedAnnDataCollection or AnnData from which to load the data.

  • batch_keys (dict[str, dict[str, AnnDataField] | AnnDataField]) – Dictionary that specifies which attributes and keys of the dadc to return in the batch data and how to convert them. Keys must correspond to the input keys of the transforms or the model. Values must be instances of cellarium.ml.utilities.data.AnnDataField.

  • batch_size (int) – How many samples per batch to load.

  • iteration_strategy (Literal['same_order', 'cache_efficient']) – Strategy to use for iterating through the dataset. Options are same_order and cache_efficient. same_order will iterate through the dataset in the same order independent of the number of replicas and workers. cache_efficient will try to minimize the number of anndata files fetched by each worker.

  • shuffle (bool) – If True, the data is reshuffled at every epoch.

  • shuffle_seed (int) – Random seed used to shuffle the sampler if shuffle=True.

  • drop_last_indices (bool) – If True, then the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas.

  • drop_incomplete_batch (bool) – If True, the dataloader will drop the incomplete batch if the dataset size is not divisible by the batch size.

  • start_idx (int | None) – The starting index of the dataset. If None, then the dataset will start from the first index.

  • end_idx (int | None) – The ending index (exclusive) of the dataset. If None, then the dataset will end at the last index (inclusive).

  • worker_seed (int | None) – Random seed used to seed the workers. If None, then the workers will not be seeded. The seed of the individual worker is computed based on the worker_seed, global worker id, and the epoch. Note that this seed affects cpu_transforms when they are used. When resuming training, the seed should be set to a different value to ensure that the workers are not seeded with the same seed as the previous run.

  • test_mode (bool) – If True, then tracking of cache and worker information will be enabled.

__len__() → int

Returns the number of batches per replica.

Return type:

int
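A sketch of how the number of batches per replica could follow from the constructor flags is shown below. num_batches_per_replica is hypothetical, ignores start_idx and end_idx, and assumes drop_incomplete_batch applies to each replica's share of indices. Under those assumptions it agrees with the batch counts in Examples 1 to 6 of __iter__.

```python
def _ceil_div(a: int, b: int) -> int:
    # Integer ceiling division without floating point.
    return -(-a // b)


def num_batches_per_replica(
    n_indices: int,
    batch_size: int,
    num_replicas: int = 1,
    drop_last_indices: bool = False,
    drop_incomplete_batch: bool = False,
) -> int:
    # Indices per replica: drop the tail, or pad so every replica gets the same count.
    if drop_last_indices:
        per_replica = n_indices // num_replicas
    else:
        per_replica = _ceil_div(n_indices, num_replicas)
    # Batches per replica: a trailing partial batch is kept unless drop_incomplete_batch.
    if drop_incomplete_batch:
        return per_replica // batch_size
    return _ceil_div(per_replica, batch_size)
```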

set_epoch(epoch: int) → None

Sets the epoch for the iterator. When shuffle=True, this ensures all replicas use a different random ordering for each epoch.

Parameters:

epoch (int)

Return type:

None

set_resume_step(resume_step: int | None) → None

Sets the resume step for the iterator. When resuming from a checkpoint, this ensures that the iterator skips the batches that have already been processed.

Parameters:

resume_step (int | None)

Return type:

None

__getitem__(idx: int | list[int] | slice) → dict[str, dict[str, ndarray] | ndarray]

Returns a dictionary containing the data from the dadc with keys specified by the batch_keys at the given index idx.

Parameters:

idx (int | list[int] | slice)

Return type:

dict[str, dict[str, ndarray] | ndarray]

__iter__()

Iterate through the dataset, trying to minimize the number of anndata files fetched by each worker. Iterated indices are evenly divided between replicas (see drop_last_indices).

Note

  1. For both strategies, the number of anndata files fetched is reduced by shuffling the shards first and then the datapoints within the shards.

  2. The same_order strategy will iterate through the dataset in the same order independent of the number of replicas and workers.

  3. For the cache_efficient strategy, the number of anndata files fetched is further reduced by assigning each worker a contiguous chunk of the dataset. The returned iterator is determined by the torch.utils.data.get_worker_info() and torch.distributed contexts.

Example 1:

indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
num_replicas=1
batch_size=2
num_workers=3

Same order:

batch idx   0      1      2      3      4      5
indices     (0,1)  (2,3)  (4,5)  (6,7)  (8,9)  (10,11)
worker id   0      1      2      0      1      2

Cache efficient:

batch idx   0      1      2      3      4      5
indices     (0,1)  (4,5)  (8,9)  (2,3)  (6,7)  (10,11)
worker id   0      1      2      0      1      2

Example 2:

indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
num_replicas=1
batch_size=2
num_workers=2

Same order:

batch idx   0      1      2      3      4      5
indices     (0,1)  (2,3)  (4,5)  (6,7)  (8,9)  (10,)
worker id   0      1      0      1      0      1

Cache efficient:

batch idx   0      1      2      3      4      5
indices     (0,1)  (6,7)  (2,3)  (8,9)  (4,5)  (10,)
worker id   0      1      0      1      0      1

Example 3:

indices=[0, 1, 2, 3, 4, 5, 6, 7]
num_replicas=1
batch_size=3
num_workers=2

Same order:

batch idx   0        1        2
indices     (0,1,2)  (3,4,5)  (6,7)
worker id   0        1        0

Cache efficient:

batch idx   0        1        2
indices     (0,1,2)  (6,7)    (3,4,5)
worker id   0        1        0

Example 4:

indices=[0, 1, 2, 3, 4, 5, 6, 7]
num_replicas=1
batch_size=3
drop_incomplete_batch=True
num_workers=2

Same order:

batch idx   0        1
indices     (0,1,2)  (3,4,5)
worker id   0        1

Cache efficient:

batch idx   0        1
indices     (0,1,2)  (3,4,5)
worker id   0        1
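The single-replica tables in Examples 1 to 4 can be reproduced with the sketch below. worker_schedule is hypothetical, uses shuffle=False, and assumes the DataLoader polls workers round-robin and skips workers that are exhausted. same_order deals batches to workers in turn, while cache_efficient gives each worker one contiguous chunk of batches.

```python
from __future__ import annotations


def worker_schedule(
    n_indices: int,
    batch_size: int,
    num_workers: int,
    strategy: str,
    drop_incomplete_batch: bool = False,
) -> list[tuple[tuple[int, ...], int]]:
    """Return (batch indices, worker id) pairs in the order they would be yielded."""
    indices = list(range(n_indices))
    batches = [tuple(indices[i : i + batch_size]) for i in range(0, n_indices, batch_size)]
    if drop_incomplete_batch and batches and len(batches[-1]) < batch_size:
        batches.pop()

    per_worker: list[list[tuple[int, ...]]] = [[] for _ in range(num_workers)]
    if strategy == "same_order":
        # Batch b goes to worker b % num_workers, so the global order is unchanged.
        for b, batch in enumerate(batches):
            per_worker[b % num_workers].append(batch)
    elif strategy == "cache_efficient":
        # Each worker gets one contiguous chunk of batches (fewer shards per worker).
        chunk = -(-len(batches) // num_workers)
        for w in range(num_workers):
            per_worker[w] = batches[w * chunk : (w + 1) * chunk]
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")

    # Round-robin over workers, skipping those that have no batches left.
    schedule = []
    for step in range(max(len(p) for p in per_worker)):
        for w, worker_batches in enumerate(per_worker):
            if step < len(worker_batches):
                schedule.append((worker_batches[step], w))
    return schedule
```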

Example 5:

indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
num_replicas=2
drop_last_indices=True
batch_size=2
num_workers=1

Same order:

Replica 1

batch idx   0       1       2
indices     (0,2)   (4,6)   (8,)
worker id   0       0       0

Replica 2

batch idx   0       1       2
indices     (1,3)   (5,7)   (9,)
worker id   0       0       0

Cache efficient:

Replica 1

batch idx   0       1       2
indices     (0,1)   (2,3)   (4,)
worker id   0       0       0

Replica 2

batch idx   0       1       2
indices     (5,6)   (7,8)   (9,)
worker id   0       0       0

Example 6:

indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
num_replicas=2
drop_last_indices=False
batch_size=2
num_workers=1

Same order:

Replica 1

batch idx   0       1       2
indices     (0,2)   (4,6)   (8,10)
worker id   0       0       0

Replica 2

batch idx   0       1       2
indices     (1,3)   (5,7)   (9,0)
worker id   0       0       0

Cache efficient:

Replica 1

batch idx   0       1       2
indices     (0,1)   (2,3)   (4,5)
worker id   0       0       0

Replica 2

batch idx   0       1       2
indices     (6,7)   (8,9)   (10,0)
worker id   0       0       0
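The replica tables in Examples 5 and 6 can be reproduced with the sketch below. replica_batches is hypothetical and covers only a single worker with shuffle=False. drop_last_indices trims the tail, otherwise indices wrap around from the start as padding. same_order then takes a strided slice per replica, and cache_efficient takes a contiguous block.

```python
from __future__ import annotations


def replica_batches(
    n_indices: int,
    num_replicas: int,
    rank: int,
    batch_size: int,
    strategy: str,
    drop_last_indices: bool,
) -> list[tuple[int, ...]]:
    """Batches seen by one replica (rank 0 is "Replica 1") with a single worker."""
    indices = list(range(n_indices))
    if drop_last_indices:
        total = (n_indices // num_replicas) * num_replicas
        indices = indices[:total]  # drop the tail
    else:
        total = -(-n_indices // num_replicas) * num_replicas
        padding = total - n_indices
        repeats = -(-padding // n_indices) if n_indices else 0
        indices = indices + (indices * repeats)[:padding]  # wrap around from the start
    per_replica = total // num_replicas
    if strategy == "same_order":
        mine = indices[rank::num_replicas]  # strided: replicas interleave
    elif strategy == "cache_efficient":
        mine = indices[rank * per_replica : (rank + 1) * per_replica]  # contiguous block
    else:
        raise ValueError(f"unknown strategy: {strategy!r}")
    return [tuple(mine[i : i + batch_size]) for i in range(0, len(mine), batch_size)]
```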

Resuming from a checkpoint:

  1. For persistent workers the state (epoch and resume_step) is initially set by the load_state_dict() method. At the end of the iteration, the epoch is incremented and the resume_step is set to None.

  2. For non-persistent workers the state is initially set by the load_state_dict() method. The epoch is updated by the on_train_epoch_start hook and the resume_step is set to None by the on_train_epoch_end hook.

  3. If the resume_step is not None, then the worker will skip the batches that have already been processed. The workers are shifted based on the global step.
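One way to think about shifting workers on resume is sketched below. resume_offsets is hypothetical and assumes that resume_step counts batches already consumed in the current epoch, that the DataLoader consumed them round-robin starting from worker 0, and that every worker still had batches left. Under those assumptions, each worker skips the batches it already produced, and the first batch after resuming should come from worker resume_step % num_workers.

```python
def resume_offsets(resume_step: int, num_workers: int) -> tuple[list[int], int]:
    """Per-worker number of batches to skip, and the worker that should yield next."""
    # Worker w produced the batches at positions w, w + num_workers, ... < resume_step.
    skips = [(resume_step - w + num_workers - 1) // num_workers for w in range(num_workers)]
    next_worker = resume_step % num_workers
    return skips, next_worker
```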

load_state_dict(state_dict: dict[str, Any]) → None

Loads the state of the dataset from the given state dictionary.

Parameters:

state_dict (dict[str, Any]) – State dictionary containing the state of the dataset.

Return type:

None

class cellarium.ml.data.LazyAnnData(filename: str, limits: tuple[int, int], schema: AnnDataSchema, cache: LRU | None = None)

Bases: object

Lazy AnnData backed by a file.

Accessing attributes under lazy_getattr() context returns schema attributes.

Parameters:
  • filename (str) – Name of anndata file.

  • limits (tuple[int, int]) – Limits of cell indices (inclusive, exclusive).

  • schema (AnnDataSchema) – Schema used as a reference for lazy attributes.

  • cache (LRU | None) – Shared LRU cache storing buffered anndatas.
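The lazy pattern can be sketched as a small dataclass: shape metadata comes from limits and the schema without any file I/O, and the file is read only on first access to adata. LazyFileSketch is hypothetical, and a plain dict stands in for the shared LRU cache. The loader argument would be a reader such as read_h5ad_file in practice.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class LazyFileSketch:
    """Hypothetical stand-in for LazyAnnData: metadata without I/O, data on demand."""

    filename: str
    limits: tuple[int, int]  # (inclusive start, exclusive end) global cell indices
    var_names: list[str]  # taken from the schema, not read from the file
    loader: Callable[[str], Any]  # hypothetical reader for one file
    cache: dict[str, Any] = field(default_factory=dict)  # share one dict across instances

    @property
    def n_obs(self) -> int:
        start, end = self.limits
        return end - start

    @property
    def n_vars(self) -> int:
        return len(self.var_names)

    @property
    def shape(self) -> tuple[int, int]:
        return (self.n_obs, self.n_vars)

    @property
    def cached(self) -> bool:
        return self.filename in self.cache

    @property
    def adata(self) -> Any:
        # Read the file only once; later accesses hit the shared cache.
        if self.filename not in self.cache:
            self.cache[self.filename] = self.loader(self.filename)
        return self.cache[self.filename]
```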

property n_obs: int

Number of observations.

property n_vars: int

Number of variables/features.

property shape: tuple[int, int]

Shape of the data matrix.

property obs_names: Index

Return the observation names.

property cached: bool

Return whether the anndata is cached.

property adata: AnnData

Return the AnnData object backed by filename.

cellarium.ml.data.read_h5ad_file(filename: str, **kwargs) → AnnData

Read .h5ad-formatted hdf5 file from a filename.

Parameters:

filename (str) – Path to the data file.

Return type:

AnnData
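A reader like read_h5ad_file could dispatch on the filename's URL scheme to read_h5ad_gcs, read_h5ad_url or read_h5ad_local. The dispatch rule below is an assumption, not the confirmed behavior of read_h5ad_file. To keep the sketch runnable without network access, the hypothetical choose_reader returns only the name of the reader it would pick.

```python
from urllib.parse import urlparse


def choose_reader(filename: str) -> str:
    """Name of the reader a scheme-based dispatcher might pick (hypothetical)."""
    scheme = urlparse(filename).scheme
    if scheme == "gs":
        return "read_h5ad_gcs"
    if scheme in ("http", "https"):
        return "read_h5ad_url"
    # Plain paths (and unrecognized schemes such as Windows drive letters) are local.
    return "read_h5ad_local"
```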

cellarium.ml.data.read_h5ad_gcs(filename: str, storage_client: Client | None = None) → AnnData

Read .h5ad-formatted hdf5 file from Google Cloud Storage.

Example:

>>> adata = read_h5ad_gcs("gs://dsp-cellarium-cas-public/test-data/test_0.h5ad")
Parameters:
  • filename (str) – Path to the data file in Cloud Storage.

  • storage_client (Client | None) – Optional Google Cloud Storage client.

Return type:

AnnData

cellarium.ml.data.read_h5ad_local(filename: str) → AnnData

Read .h5ad-formatted hdf5 file from the local disk.

Parameters:

filename (str) – Path to the local data file.

Return type:

AnnData

cellarium.ml.data.read_h5ad_url(filename: str) → AnnData

Read .h5ad-formatted hdf5 file from a URL.

Example:

>>> adata = read_h5ad_url(
...     "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad"
... )
Parameters:

filename (str) – URL of the data file.

Return type:

AnnData