Data
- class cellarium.ml.data.AnnDataSchema(adata: AnnData, obs_columns_to_validate: Sequence[str] | None = None)[source]
Bases: object
Store reference AnnData attributes for a collection of distributed AnnData objects.
Validate AnnData objects against reference attributes.
Example:

>>> ref_adata = AnnData(X=np.zeros((2, 3)))
>>> ref_adata.obs["batch"] = [0, 1]
>>> ref_adata.var["mu"] = ["a", "b", "c"]
>>> adata = AnnData(X=np.zeros((4, 3)))
>>> adata.obs["batch"] = [2, 3, 4, 5]
>>> adata.var["mu"] = ["a", "b", "c"]
>>> schema = AnnDataSchema(ref_adata)
>>> schema.validate_anndata(adata)
- Parameters:
adata (AnnData) – Reference AnnData object.
obs_columns_to_validate (Sequence[str] | None) – Subset of columns to validate in the .obs attribute. If None, all columns are validated.
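For contrast, a minimal sketch of a failed validation, continuing the example above; it assumes validate_anndata raises an exception when an attribute deviates from the reference (the exact exception type is not specified here):

>>> bad_adata = AnnData(X=np.zeros((4, 3)))
>>> bad_adata.obs["batch"] = [2, 3, 4, 5]
>>> bad_adata.var["mu"] = ["a", "b", "d"]  # differs from the reference var annotation
>>> schema.validate_anndata(bad_adata)  # expected to raise a validation error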
- class cellarium.ml.data.DistributedAnnDataCollection(filenames: Sequence[str] | str, limits: Iterable[int] | None = None, shard_size: int | None = None, last_shard_size: int | None = None, max_cache_size: int = 1, cache_size_strictly_enforced: bool = True, label: str | None = None, keys: Sequence[str] | None = None, index_unique: str | None = None, convert: Callable | dict[str, Callable | dict[str, Callable]] | None = None, indices_strict: bool = True, obs_columns_to_validate: Sequence[str] | None = None)[source]
Bases: AnnCollection
Distributed AnnData Collection.
This class is a wrapper around AnnCollection where adatas is a list of LazyAnnData objects.
Underlying anndata files must conform to the same schema (see validate_anndata). The schema is inferred from the first AnnData file in the collection. Individual AnnData files may otherwise vary in the number of cells, and the actual content stored in X, layers, obs, and obsm.

Example 1:
>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10000,  # use if shards are sized evenly
...     max_cache_size=2)
Example 2:
>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10000,
...     last_shard_size=6000,  # use if the size of the last shard is different
...     max_cache_size=2)
Example 3:
>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     limits=[500, 1000, 2000, 2500, 3000, 4000],  # use if shards are sized unevenly
...     max_cache_size=2)
- Parameters:
filenames (Sequence[str] | str) – Names of anndata files.
limits (Iterable[int] | None) – List of global cell indices (limits) for the last cells in each shard. If None, the limits are inferred from shard_size and last_shard_size.
shard_size (int | None) – The number of cells in each anndata file (shard). Must be specified if limits is not provided.
last_shard_size (int | None) – Last shard size. If not None, the last shard will have this size, possibly different from shard_size.
max_cache_size (int) – Max size of the cache.
cache_size_strictly_enforced (bool) – Assert that the number of retrieved anndatas is not more than max_cache_size.
label (str | None) – Column in obs to place batch information in. If None, no column is added.
keys (Sequence[str] | None) – Names for each object being added. These values are used for column values for label or appended to the index if index_unique is not None. If None, keys are set to filenames.
index_unique (str | None) – Whether to make the index unique by using the keys. If provided, this is the delimiter between {orig_idx}{index_unique}{key}. When None, the original indices are kept.
convert (Callable | dict[str, Callable | dict[str, Callable]] | None) – You can pass a function or a Mapping of functions which will be applied to the values of attributes (obs, obsm, layers, X) or to specific keys of these attributes in the subset object. Specify an attribute and a key (if needed) as keys of the passed Mapping and a function to be applied as a value.
indices_strict (bool) – If True, arrays from the subset objects will always have the same order of indices as in the selection used to subset. This parameter can be set to False if the order in the returned arrays is not important, for example, when using them for stochastic gradient descent. In this case the performance of subsetting can be a bit better.
obs_columns_to_validate (Sequence[str] | None) – Subset of columns to validate in the obs attribute. If None, all columns are validated.
- __getitem__(index: slice | int | str | int64 | ndarray | ellipsis | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray | ellipsis] | tuple[slice | int | str | int64 | ndarray | ellipsis, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray, ellipsis] | tuple[ellipsis, slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, ellipsis, slice | int | str | int64 | ndarray] | spmatrix | sparray) → AnnData [source]
Materialize and gather anndata files at given indices from the list of lazy anndatas.
LazyAnnData instances corresponding to cells in the index are materialized.
- Parameters:
index (slice | int | str | int64 | ndarray | ellipsis | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray | ellipsis] | tuple[slice | int | str | int64 | ndarray | ellipsis, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray, ellipsis] | tuple[ellipsis, slice | int | str | int64 | ndarray, slice | int | str | int64 | ndarray] | tuple[slice | int | str | int64 | ndarray, ellipsis, slice | int | str | int64 | ndarray] | spmatrix | sparray)
- Return type:
AnnData
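A usage sketch, reusing the hypothetical bucket layout from Example 1 above; slicing materializes only the shards that cover the requested cells:

>>> batch = dadc[100:200]  # with shard_size=10000, only the first shard is fetched
>>> batch.n_obs
100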
- materialize(adatas_oidx: list[ndarray | None], vidx: slice | int | str | int64 | ndarray) → list[AnnData] [source]
Buffer and return anndata files at given indices from the list of lazy anndatas.
To use the cache efficiently, already-cached files are retrieved first; only then are new files fetched and cached.
- Parameters:
adatas_oidx (list[ndarray | None])
vidx (slice | int | str | int64 | ndarray)
- Return type:
list[AnnData]
- class cellarium.ml.data.DistributedAnnDataCollectionView(reference, convert, resolved_idx)[source]
Bases: AnnCollectionView
Distributed AnnData Collection View.
This class is a wrapper around AnnCollectionView where adatas is a list of LazyAnnData objects.
- property obs_names: Index
Gather and return the obs_names from all AnnData objects in the collection.
- class cellarium.ml.data.IterableDistributedAnnDataCollectionDataset(dadc: DistributedAnnDataCollection | AnnData, batch_keys: dict[str, dict[str, AnnDataField] | AnnDataField], batch_size: int = 1, iteration_strategy: Literal['same_order', 'cache_efficient'] = 'cache_efficient', shuffle: bool = False, shuffle_seed: int = 0, drop_last_indices: bool = False, drop_incomplete_batch: bool = False, start_idx: int | None = None, end_idx: int | None = None, worker_seed: int | None = None, test_mode: bool = False)[source]
Bases: IterableDataset
Iterable DistributedAnnDataCollection Dataset.
When shuffle is set to True, the iterator yields datapoints that are uniformly sampled from the entire dataset. Typical use cases include training variational models using the stochastic gradient descent algorithm.

In order to maximize buffer usage, we only shuffle shards and datapoints within individual shards (and not across shards). Therefore, to achieve unbiased pseudo-random uniform sampling, it is imperative that the shards themselves contain datapoints that are uniformly sampled from the entire dataset. If correlations exist between datapoints in a given shard (e.g., all cells coming from the same tissue or experiment), then this assumption is violated. It is the user's responsibility to prepare appropriately shuffled data shards.
Example:
>>> from cellarium.ml.data import (
...     DistributedAnnDataCollection,
...     IterableDistributedAnnDataCollectionDataset,
... )
>>> from cellarium.ml.utilities.data import AnnDataField, densify
>>> dadc = DistributedAnnDataCollection(
...     "gs://bucket-name/folder/adata{000..005}.h5ad",
...     shard_size=10_000,
...     max_cache_size=2)
>>> dataset = IterableDistributedAnnDataCollectionDataset(
...     dadc,
...     batch_keys={
...         "x_ng": AnnDataField(attr="X", convert_fn=densify),
...         "var_names_g": AnnDataField(attr="var_names"),
...     },
...     batch_size=5000,
...     iteration_strategy="cache_efficient",
...     shuffle=True,
...     shuffle_seed=0,
...     drop_last_indices=True,
... )
- Parameters:
dadc (DistributedAnnDataCollection | AnnData) – DistributedAnnDataCollection or AnnData from which to load the data.
batch_keys (dict[str, dict[str, AnnDataField] | AnnDataField]) – Dictionary that specifies which attributes and keys of the dadc to return in the batch data and how to convert them. Keys must correspond to the input keys of the transforms or the model. Values must be instances of cellarium.ml.utilities.data.AnnDataField.
batch_size (int) – How many samples per batch to load.
iteration_strategy (Literal['same_order', 'cache_efficient']) – Strategy to use for iterating through the dataset. Options are same_order and cache_efficient. same_order will iterate through the dataset in the same order independent of the number of replicas and workers. cache_efficient will try to minimize the number of anndata files fetched by each worker.
shuffle (bool) – If True, the data is reshuffled at every epoch.
shuffle_seed (int) – Random seed used to shuffle the sampler if shuffle=True.
drop_last_indices (bool) – If True, the sampler will drop the tail of the data to make it evenly divisible across the number of replicas. If False, the sampler will add extra indices to make the data evenly divisible across the replicas.
drop_incomplete_batch (bool) – If True, the dataloader will drop the incomplete batch if the dataset size is not divisible by the batch size.
start_idx (int | None) – The starting index of the dataset. If None, the dataset will start from the first index.
end_idx (int | None) – The ending index (exclusive) of the dataset. If None, the dataset will end at the last index (inclusive).
worker_seed (int | None) – Random seed used to seed the workers. If None, the workers will not be seeded. The seed of an individual worker is computed from the worker_seed, the global worker id, and the epoch. Note that this seed affects cpu_transforms when they are used. When resuming training, the seed should be set to a different value to ensure that the workers are not seeded with the same seed as the previous run.
test_mode (bool) – If True, tracking of cache and worker information will be enabled.
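Since the dataset assembles batches of batch_size itself, it is typically wrapped in a torch.utils.data.DataLoader with automatic batching disabled. A minimal sketch, reusing the dataset from the example above:

>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(dataset, batch_size=None, num_workers=2, persistent_workers=True)
>>> for batch in loader:
...     x_ng = batch["x_ng"]  # one full batch already assembled by the dataset
...     ...

Passing batch_size=None to the DataLoader disables automatic batching, so each yielded item is exactly one batch produced by the dataset.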
- set_epoch(epoch: int) → None [source]
Sets the epoch for the iterator. When shuffle=True, this ensures all replicas use a different random ordering for each epoch.
- Parameters:
epoch (int)
- Return type:
None
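As with torch.utils.data.distributed.DistributedSampler, set_epoch should be called before iterating each epoch when shuffle=True. A sketch, assuming the dataset and loader from the examples above and a user-defined num_epochs:

>>> for epoch in range(num_epochs):
...     dataset.set_epoch(epoch)  # advance the shuffle ordering for this epoch
...     for batch in loader:
...         ...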
- set_resume_step(resume_step: int | None) → None [source]
Sets the resume step for the iterator. When resuming from a checkpoint, this ensures that the iterator skips the batches that have already been processed.
- Parameters:
resume_step (int | None)
- Return type:
None
- __getitem__(idx: int | list[int] | slice) → dict[str, dict[str, ndarray] | ndarray] [source]
Returns a dictionary containing the data from the dadc with keys specified by the batch_keys at the given index idx.
- Parameters:
idx (int | list[int] | slice)
- Return type:
dict[str, dict[str, ndarray] | ndarray]
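A sketch of random access, continuing the dataset constructed in the class-level example above (whose batch_keys are x_ng and var_names_g):

>>> batch = dataset[:5]  # one dict entry per batch key
>>> sorted(batch.keys())
['var_names_g', 'x_ng']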
- __iter__()[source]
Iterate through the dataset by trying to minimize the number of anndata files fetched by each worker. Iterated indices are evenly divided between replicas (see drop_last_indices).

Note
For both strategies, the number of anndata files fetched is reduced by shuffling the shards first and then the datapoints within the shards. The same_order strategy will iterate through the dataset in the same order independent of the number of replicas and workers. For the cache_efficient strategy, the number of anndata files fetched is further reduced by assigning to each worker a contiguous chunk of the dataset. The returned iterator is determined by the torch.utils.data.get_worker_info() and torch.distributed contexts.
Example 1:
indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], num_replicas=1, batch_size=2, num_workers=3

Same order:

| batch idx | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| indices | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,11) |
| worker id | 0 | 1 | 2 | 0 | 1 | 2 |

Cache efficient:

| batch idx | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| indices | (0,1) | (4,5) | (8,9) | (2,3) | (6,7) | (10,11) |
| worker id | 0 | 1 | 2 | 0 | 1 | 2 |
Example 2:
indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], num_replicas=1, batch_size=2, num_workers=2

Same order:

| batch idx | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| indices | (0,1) | (2,3) | (4,5) | (6,7) | (8,9) | (10,) |
| worker id | 0 | 1 | 0 | 1 | 0 | 1 |

Cache efficient:

| batch idx | 0 | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| indices | (0,1) | (6,7) | (2,3) | (8,9) | (4,5) | (10,) |
| worker id | 0 | 1 | 0 | 1 | 0 | 1 |
Example 3:
indices=[0, 1, 2, 3, 4, 5, 6, 7], num_replicas=1, batch_size=3, num_workers=2

Same order:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,1,2) | (3,4,5) | (6,7) |
| worker id | 0 | 1 | 0 |

Cache efficient:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,1,2) | (6,7) | (3,4,5) |
| worker id | 0 | 1 | 0 |
Example 4:
indices=[0, 1, 2, 3, 4, 5, 6, 7], num_replicas=1, batch_size=3, drop_incomplete_batch=True, num_workers=2

Same order:

| batch idx | 0 | 1 |
|---|---|---|
| indices | (0,1,2) | (3,4,5) |
| worker id | 0 | 1 |

Cache efficient:

| batch idx | 0 | 1 |
|---|---|---|
| indices | (0,1,2) | (3,4,5) |
| worker id | 0 | 1 |
Example 5:
indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], num_replicas=2, drop_last_indices=True, batch_size=2, num_workers=1

Same order:

Replica 1:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,2) | (4,6) | (8,) |
| worker id | 0 | 0 | 0 |

Replica 2:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (1,3) | (5,7) | (9,) |
| worker id | 0 | 0 | 0 |

Cache efficient:

Replica 1:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,1) | (2,3) | (4,) |
| worker id | 0 | 0 | 0 |

Replica 2:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (5,6) | (7,8) | (9,) |
| worker id | 0 | 0 | 0 |
Example 6:
indices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], num_replicas=2, drop_last_indices=False, batch_size=2, num_workers=1

Same order:

Replica 1:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,2) | (4,6) | (8,10) |
| worker id | 0 | 0 | 0 |

Replica 2:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (1,3) | (5,7) | (9,0) |
| worker id | 0 | 0 | 0 |

Cache efficient:

Replica 1:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (0,1) | (2,3) | (4,5) |
| worker id | 0 | 0 | 0 |

Replica 2:

| batch idx | 0 | 1 | 2 |
|---|---|---|---|
| indices | (6,7) | (8,9) | (10,0) |
| worker id | 0 | 0 | 0 |
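To make the two strategies concrete, here is an illustrative pure-Python sketch (not the library's internal code) that reproduces the batch-to-worker assignments of Examples 1-3 for a single replica without shuffling: same_order hands out batches round-robin across workers, while cache_efficient gives each worker a contiguous chunk of batches, so consecutive fetches by a worker tend to hit the same shard:

>>> def assign_batches(indices, batch_size, num_workers):
...     batches = [tuple(indices[i:i + batch_size]) for i in range(0, len(indices), batch_size)]
...     # same_order: batches in dataset order, assigned round-robin across workers
...     same_order = [(b % num_workers, batch) for b, batch in enumerate(batches)]
...     # cache_efficient: each worker owns a contiguous chunk of batches
...     per_worker = -(-len(batches) // num_workers)  # ceiling division
...     cache_efficient = [
...         (w, batches[w * per_worker + step])
...         for step in range(per_worker)
...         for w in range(num_workers)
...         if w * per_worker + step < len(batches)
...     ]
...     return same_order, cache_efficient
>>> _, cache = assign_batches(list(range(12)), batch_size=2, num_workers=3)
>>> [batch for _, batch in cache]  # matches the cache-efficient row of Example 1
[(0, 1), (4, 5), (8, 9), (2, 3), (6, 7), (10, 11)]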
Resuming from a checkpoint:
For persistent workers, the state (epoch and resume_step) is initially set by the load_state_dict() method. At the end of the iteration, the epoch is incremented and the resume_step is set to None.

For non-persistent workers, the state is initially set by the load_state_dict() method. The epoch is updated by the on_train_epoch_start hook, and the resume_step is set to None by the on_train_epoch_end hook.

If the resume_step is not None, the worker will skip the batches that have already been processed. The workers are shifted based on the global step.
- class cellarium.ml.data.LazyAnnData(filename: str, limits: tuple[int, int], schema: AnnDataSchema, cache: LRU | None = None)[source]
Bases: object
Lazy AnnData backed by a file.
Accessing attributes under the lazy_getattr() context returns schema attributes.
- Parameters:
filename (str) – Name of anndata file.
limits (tuple[int, int]) – Limits of cell indices (inclusive, exclusive).
schema (AnnDataSchema) – Schema used as a reference for lazy attributes.
cache (LRU | None) – Shared LRU cache storing buffered anndatas.
- property n_obs: int
Number of observations.
- property n_vars: int
Number of variables/features.
- property shape: tuple[int, int]
Shape of the data matrix.
- property obs_names: Index
Return the observation names.
- property cached: bool
Return whether the anndata is cached.
- property adata: AnnData
Return the backed anndata from the filename.
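A behavioral sketch with hypothetical arguments (schema is an AnnDataSchema built from a reference file, as above): the file is only read when the anndata is first materialized, after which cached reports True:

>>> lazy = LazyAnnData("adata000.h5ad", limits=(0, 10000), schema=schema)
>>> lazy.cached
False
>>> adata = lazy.adata  # reads the file and stores it in the shared LRU cache
>>> lazy.cached
True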
- cellarium.ml.data.read_h5ad_file(filename: str, **kwargs) → AnnData [source]
Read an .h5ad-formatted hdf5 file from a filename.
- Parameters:
filename (str) – Path to the data file.
- Return type:
AnnData
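A usage sketch with a hypothetical path; judging by the companion readers below, the filename may point to Cloud Storage, a URL, or local disk:

>>> adata = read_h5ad_file("gs://bucket-name/folder/adata000.h5ad")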
- cellarium.ml.data.read_h5ad_gcs(filename: str, storage_client: Client | None = None) → AnnData [source]
Read an .h5ad-formatted hdf5 file from Google Cloud Storage.
Example:
>>> adata = read_h5ad_gcs("gs://dsp-cellarium-cas-public/test-data/test_0.h5ad")
- Parameters:
filename (str) – Path to the data file in Cloud Storage.
storage_client (Client | None)
- Return type:
AnnData
- cellarium.ml.data.read_h5ad_local(filename: str) → AnnData [source]
Read an .h5ad-formatted hdf5 file from the local disk.
- Parameters:
filename (str) – Path to the local data file.
- Return type:
AnnData
- cellarium.ml.data.read_h5ad_url(filename: str) → AnnData [source]
Read an .h5ad-formatted hdf5 file from a URL.
Example:
>>> adata = read_h5ad_url(
...     "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_0.h5ad"
... )
- Parameters:
filename (str) – URL of the data file.
- Return type:
AnnData