Source code for cellarium.ml.data.distributed_anndata

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import gc
from collections.abc import Iterable, Sequence
from contextlib import contextmanager

import numpy as np
import pandas as pd
from anndata import AnnData, concat
from anndata._core.index import _normalize_indices
from anndata.compat import Index, Index1D
from anndata.experimental.multi_files._anncollection import (
    AnnCollection,
    AnnCollectionView,
    ConvertType,
)
from boltons.cacheutils import LRU
from braceexpand import braceexpand

from cellarium.ml.data.fileio import read_h5ad_file
from cellarium.ml.data.schema import AnnDataSchema


class getattr_mode:
    lazy = False


_GETATTR_MODE = getattr_mode()


@contextmanager
def lazy_getattr():
    """
    When in lazy getattr mode, return a :class:`LazyAnnData` object attribute instead of
    the AnnData object attribute.
    """
    try:
        _GETATTR_MODE.lazy = True
        yield
    finally:
        _GETATTR_MODE.lazy = False

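# Usage sketch (illustrative only; ``lazy_adata`` is assumed to be an existing
# :class:`LazyAnnData` instance). Inside the context, lazy attributes are answered from the
# shared :class:`AnnDataSchema` and synthetic obs names, so no backed file is read:
#
#     >>> with lazy_getattr():
#     ...     lazy_adata.var_names  # served from schema.attr_values, no file I/O
#     >>> lazy_adata.var_names  # outside the context, the shard is fetched and cached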

class DistributedAnnDataCollectionView(AnnCollectionView):
    """
    Distributed AnnData Collection View.

    This class is a wrapper around AnnCollectionView where adatas is a list of
    :class:`LazyAnnData` objects.
    """

    def __getitem__(self, index: Index) -> "DistributedAnnDataCollectionView":
        oidx, vidx = _normalize_indices(index, self.obs_names, self.var_names)
        resolved_idx = self._resolve_idx(oidx, vidx)
        return DistributedAnnDataCollectionView(self.reference, self.convert, resolved_idx)

    @property
    def obs_names(self) -> pd.Index:
        """
        Gather and return the obs_names from all AnnData objects in the collection.
        """
        indices = []
        for i, oidx in enumerate(self.adatas_oidx):
            if oidx is None:
                continue
            adata = self.adatas[i]
            indices.append(adata.obs_names[oidx])

        if len(indices) > 1:
            concat_indices = pd.concat([pd.Series(idx) for idx in indices], ignore_index=True)
            obs_names = pd.Index(concat_indices)
            obs_names = obs_names if self.reverse is None else obs_names[self.reverse]
        else:
            obs_names = indices[0]
        return obs_names
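# Illustrative sketch (assuming ``view`` is an existing DistributedAnnDataCollectionView;
# the gene names below are hypothetical): indexing a view resolves the requested obs/var
# indices against the parent collection and returns another lazy view; shards are only
# materialized when array attributes such as ``X`` are actually accessed.
#
#     >>> subview = view[:100, ["GENE_A", "GENE_B"]]
#     >>> subview.obs_names  # gathered per shard and reordered to match the selection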
class DistributedAnnDataCollection(AnnCollection):
    r"""
    Distributed AnnData Collection.

    This class is a wrapper around AnnCollection where adatas is a list of LazyAnnData objects.

    Underlying anndata files must conform to the same schema
    (see :meth:`~cellarium.ml.data.schema.AnnDataSchema.validate_anndata`).
    The schema is inferred from the first AnnData file in the collection. Individual AnnData files may
    otherwise vary in the number of cells, and the actual content stored in :attr:`X`, :attr:`layers`,
    :attr:`obs` and :attr:`obsm`.

    Example 1::

        >>> dadc = DistributedAnnDataCollection(
        ...     "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...     shard_size=10000,  # use if shards are sized evenly
        ...     max_cache_size=2)

    Example 2::

        >>> dadc = DistributedAnnDataCollection(
        ...     "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...     shard_size=10000,
        ...     last_shard_size=6000,  # use if the size of the last shard is different
        ...     max_cache_size=2)

    Example 3::

        >>> dadc = DistributedAnnDataCollection(
        ...     "gs://bucket-name/folder/adata{000..005}.h5ad",
        ...     limits=[500, 1000, 2000, 2500, 3000, 4000],  # use if shards are sized unevenly
        ...     max_cache_size=2)

    Args:
        filenames:
            Names of anndata files.
        limits:
            List of global cell indices (limits) for the last cells in each shard.
            If ``None``, the limits are inferred from ``shard_size`` and ``last_shard_size``.
        shard_size:
            The number of cells in each anndata file (shard).
            Must be specified if ``limits`` is not provided.
        last_shard_size:
            Last shard size. If not ``None``, the last shard will have this size,
            possibly different from ``shard_size``.
        max_cache_size:
            Max size of the cache.
        cache_size_strictly_enforced:
            Assert that the number of retrieved anndatas is not more than ``max_cache_size``.
        label:
            Column in :attr:`obs` to place batch information in. If it's ``None``, no column is added.
        keys:
            Names for each object being added. These values are used for column values for
            ``label`` or appended to the index if ``index_unique`` is not ``None``.
            If ``None``, ``keys`` are set to ``filenames``.
        index_unique:
            Whether to make the index unique by using the keys. If provided, this
            is the delimiter between ``{orig_idx}{index_unique}{key}``. When ``None``,
            the original indices are kept.
        convert:
            You can pass a function or a Mapping of functions which will be applied to
            the values of attributes (:attr:`obs`, :attr:`obsm`, :attr:`layers`, :attr:`X`) or to specific
            keys of these attributes in the subset object. Specify an attribute and a key
            (if needed) as keys of the passed Mapping and a function to be applied as a value.
        indices_strict:
            If ``True``, arrays from the subset objects will always have the same order
            of indices as in the selection used to subset.
            This parameter can be set to ``False`` if the order in the returned arrays
            is not important, for example, when using them for stochastic gradient descent.
            In this case the performance of subsetting can be a bit better.
        obs_columns_to_validate:
            Subset of columns to validate in the :attr:`obs` attribute.
            If ``None``, all columns are validated.
    """

    def __init__(
        self,
        filenames: Sequence[str] | str,
        limits: Iterable[int] | None = None,
        shard_size: int | None = None,
        last_shard_size: int | None = None,
        max_cache_size: int = 1,
        cache_size_strictly_enforced: bool = True,
        label: str | None = None,
        keys: Sequence[str] | None = None,
        index_unique: str | None = None,
        convert: ConvertType | None = None,
        indices_strict: bool = True,
        obs_columns_to_validate: Sequence[str] | None = None,
    ):
        self.filenames = list(braceexpand(filenames) if isinstance(filenames, str) else filenames)
        if (shard_size is None) and (last_shard_size is not None):
            raise ValueError("If `last_shard_size` is specified then `shard_size` must also be specified.")
        if limits is None:
            if shard_size is None:
                raise ValueError("If `limits` is `None` then `shard_size` must be specified.")
            limits = [shard_size * (i + 1) for i in range(len(self.filenames))]
            if last_shard_size is not None:
                limits[-1] = limits[-1] - shard_size + last_shard_size
        else:
            limits = list(limits)
        if len(limits) != len(self.filenames):
            raise ValueError(
                f"The number of points in `limits` ({len(limits)}) must match "
                f"the number of `filenames` ({len(self.filenames)})."
            )
        # lru cache
        self.cache = LRU(max_cache_size)
        self.max_cache_size = max_cache_size
        self.cache_size_strictly_enforced = cache_size_strictly_enforced
        # schema
        adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0])
        if len(adata0) != limits[0]:
            raise ValueError(
                f"The number of cells in the first anndata file ({len(adata0)}) "
                f"does not match the first limit ({limits[0]})."
            )
        self.obs_columns_to_validate = obs_columns_to_validate
        self.schema = AnnDataSchema(adata0, obs_columns_to_validate)
        # lazy anndatas
        lazy_adatas = [
            LazyAnnData(filename, (start, end), self.schema, self.cache)
            for start, end, filename in zip([0] + limits, limits, self.filenames)
        ]
        # use filenames as default keys
        if keys is None:
            keys = self.filenames
        if len(keys) != len(self.filenames):
            raise ValueError(
                f"The number of keys ({len(keys)}) must match the number of `filenames` ({len(self.filenames)})."
            )
        with lazy_getattr():
            super().__init__(
                adatas=lazy_adatas,
                join_obs=None,
                join_obsm=None,
                join_vars=None,
                label=label,
                keys=keys,
                index_unique=index_unique,
                convert=convert,
                harmonize_dtypes=False,
                indices_strict=indices_strict,
            )
    def __getitem__(self, index: Index) -> AnnData:
        """
        Materialize and gather anndata files at given indices from the list of lazy anndatas.

        :class:`LazyAnnData` instances corresponding to cells in the index are materialized.
        """
        oidx, vidx = _normalize_indices(index, self.obs_names, self.var_names)
        adatas_oidx, oidx, vidx, reverse = self._resolve_idx(oidx, vidx)
        adatas = self.materialize(adatas_oidx, vidx)
        adata = concat(adatas, merge="same")
        adata = adata if reverse is None else adata[reverse]
        # make sure that categorical dtypes are preserved
        adata.obs = adata.obs.astype(self.schema.attr_values["obs"].dtypes)
        return adata
    def materialize(self, adatas_oidx: list[np.ndarray | None], vidx: Index1D) -> list[AnnData]:
        """
        Buffer and return anndata files at given indices from the list of lazy anndatas.

        This efficiently first retrieves cached files and only then caches new files.
        """
        adata_idx_to_oidx = {i: oidx for i, oidx in enumerate(adatas_oidx) if oidx is not None}
        n_adatas = len(adata_idx_to_oidx)
        if self.cache_size_strictly_enforced:
            if n_adatas > self.max_cache_size:
                raise ValueError(
                    f"Expected the number of anndata files ({n_adatas}) to be "
                    f"no more than the max cache size ({self.max_cache_size})."
                )
        adatas = [None] * n_adatas
        # first fetch cached anndata files
        # this ensures that cached entries are not evicted by the LRU policy before they are used
        for i, (adata_idx, oidx) in enumerate(adata_idx_to_oidx.items()):
            if self.adatas[adata_idx].cached:
                adatas[i] = self.adatas[adata_idx].adata[oidx, vidx]
        # only then cache new anndata files
        for i, (adata_idx, oidx) in enumerate(adata_idx_to_oidx.items()):
            if not self.adatas[adata_idx].cached:
                adatas[i] = self.adatas[adata_idx].adata[oidx, vidx]
        return adatas
    def __repr__(self) -> str:
        n_obs, n_vars = self.shape
        descr = f"DistributedAnnDataCollection object with n_obs × n_vars = {n_obs} × {n_vars}"
        descr += f"\n constructed from {len(self.filenames)} AnnData objects"
        for attr, keys in self._view_attrs_keys.items():
            if len(keys) > 0:
                descr += f"\n view of {attr}: {str(keys)[1:-1]}"
        for attr in self._attrs:
            keys = list(getattr(self, attr).keys())
            if len(keys) > 0:
                descr += f"\n {attr}: {str(keys)[1:-1]}"
        if "obs" in self._view_attrs_keys:
            keys = list(self.obs.keys())
            if len(keys) > 0:
                descr += f"\n own obs: {str(keys)[1:-1]}"
        return descr

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["cache"]
        del state["adatas"]
        del state["obs_names"]
        del state["schema"]
        del state["_obs"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.cache = LRU(self.max_cache_size)
        adata0 = self.cache[self.filenames[0]] = read_h5ad_file(self.filenames[0])
        self.schema = AnnDataSchema(adata0, self.obs_columns_to_validate)
        self.adatas = [
            LazyAnnData(filename, (start, end), self.schema, self.cache)
            for start, end, filename in zip([0] + self.limits, self.limits, self.filenames)
        ]
        self.obs_names = pd.Index([f"cell_{i}" for i in range(self.limits[-1])])
        self._obs = pd.DataFrame(index=self.obs_names)
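# Usage sketch (hypothetical bucket path): ``__getstate__``/``__setstate__`` above drop the LRU
# cache, lazy adatas, schema, and obs names before pickling and rebuild them from the first
# shard after unpickling, so a collection can be shipped cheaply to worker processes.
#
#     >>> import pickle
#     >>> dadc = DistributedAnnDataCollection(
#     ...     "gs://bucket-name/folder/adata{000..005}.h5ad", shard_size=10000, max_cache_size=2)
#     >>> restored = pickle.loads(pickle.dumps(dadc))  # re-reads only the first shard
#     >>> adata = restored[:128]  # materializes the shard(s) covering the first 128 cells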
class LazyAnnData:
    """
    Lazy :class:`~anndata.AnnData` backed by a file.

    Accessing attributes under the :func:`lazy_getattr` context returns schema attributes.

    Args:
        filename:
            Name of anndata file.
        limits:
            Limits of cell indices (inclusive, exclusive).
        schema:
            Schema used as a reference for lazy attributes.
        cache:
            Shared LRU cache storing buffered anndatas.
    """

    _lazy_attrs = ["obs", "obsm", "layers", "var", "varm", "varp", "var_names"]
    _all_attrs = [
        "obs",
        "var",
        "uns",
        "obsm",
        "varm",
        "layers",
        "obsp",
        "varp",
    ]

    def __init__(
        self,
        filename: str,
        limits: tuple[int, int],
        schema: AnnDataSchema,
        cache: LRU | None = None,
    ):
        self.filename = filename
        self.limits = limits
        self.schema = schema
        if cache is None:
            cache = LRU()
        self.cache = cache

    @property
    def n_obs(self) -> int:
        """Number of observations."""
        return self.limits[1] - self.limits[0]

    @property
    def n_vars(self) -> int:
        """Number of variables/features."""
        return len(self.schema.attr_values["var_names"])

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the data matrix."""
        return self.n_obs, self.n_vars

    @property
    def obs_names(self) -> pd.Index:
        """Return the observation names."""
        if _GETATTR_MODE.lazy:
            # This is only used during the initialization of DistributedAnnDataCollection
            return pd.Index([f"cell_{i}" for i in range(*self.limits)])
        else:
            return self.adata.obs_names

    @property
    def cached(self) -> bool:
        """Return whether the anndata is cached."""
        return self.filename in self.cache

    @property
    def adata(self) -> AnnData:
        """Return backed anndata from the filename."""
        try:
            adata = self.cache[self.filename]
        except KeyError:
            # fetch anndata
            adata = read_h5ad_file(self.filename)
            # validate anndata
            if self.n_obs != adata.n_obs:
                raise ValueError(
                    "Expected `n_obs` for LazyAnnData object and backed anndata to match "
                    f"but found {self.n_obs} and {adata.n_obs}, respectively."
                )
            self.schema.validate_anndata(adata)
            # cache anndata
            if len(self.cache) < self.cache.max_size:
                self.cache[self.filename] = adata
            else:
                self.cache[self.filename] = adata
                # garbage collection of AnnData is not reliable
                # therefore we call garbage collection manually to free up the memory
                # https://github.com/scverse/anndata/issues/360
                gc.collect()
        return adata

    def __getattr__(self, attr):
        if _GETATTR_MODE.lazy:
            # This is only used during the initialization of DistributedAnnDataCollection
            if attr in self._lazy_attrs:
                return self.schema.attr_values[attr]
            raise AttributeError(f"Lazy AnnData object has no attribute '{attr}'")
        else:
            adata = self.adata
            if hasattr(adata, attr):
                return getattr(adata, attr)
            raise AttributeError(f"Backed AnnData object has no attribute '{attr}'")

    def __getitem__(self, idx) -> AnnData:
        return self.adata[idx]

    def __repr__(self) -> str:
        if self.cached:
            buffered = "Cached "
        else:
            buffered = ""
        backed_at = f" backed at {str(self.filename)!r}"
        descr = f"{buffered}LazyAnnData object with n_obs × n_vars = {self.n_obs} × {self.n_vars}{backed_at}"
        if self.cached:
            for attr in self._all_attrs:
                keys = getattr(self, attr).keys()
                if len(keys) > 0:
                    descr += f"\n {attr}: {str(list(keys))[1:-1]}"
        return descr
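# Usage sketch (hypothetical filename; ``schema`` is assumed to be an AnnDataSchema built
# from a reference shard): a LazyAnnData reads its backing file only on first real access,
# and ``cached`` reports whether the shard currently lives in the shared LRU cache.
#
#     >>> lazy_adata = LazyAnnData("adata000.h5ad", (0, 10000), schema, cache=LRU(2))
#     >>> lazy_adata.shape   # (10000, n_vars), computed from limits and the schema only
#     >>> lazy_adata.cached  # False: nothing has been read yet
#     >>> lazy_adata[0:5].X  # triggers read_h5ad_file, schema validation, and caching
#     >>> lazy_adata.cached  # True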