Source code for cellarium.ml.data.schema

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import numpy as np
from anndata import AnnData


[docs] class AnnDataSchema: r""" Store reference AnnData attributes for a collection of distributed AnnData objects. Validate AnnData objects against reference attributes. Example:: >>> ref_adata = AnnData(X=np.zeros(2, 3)) >>> ref_adata.obs["batch"] = [0, 1] >>> ref_adata.var["mu"] = ["a", "b", "c"] >>> adata = AnnData(X=np.zeros(4, 3)) >>> adata.obs["batch"] = [2, 3, 4, 5] >>> adata.var["mu"] = ["a", "b", "c"] >>> schema = AnnDataSchema(ref_adata) >>> schema.validate_anndata(adata) Args: adata: Reference AnnData object. obs_columns_to_validate: Subset of columns to validate in the ``.obs`` attribute. If ``None``, all columns are validated. """ attrs = ["obs", "obsm", "var", "varm", "varp", "var_names", "layers"] def __init__(self, adata: AnnData, obs_columns_to_validate: Sequence[str] | None = None) -> None: self.attr_values = {} for attr in self.attrs: # FIXME: some of the attributes have a reference to the anndata object itself. # This results in anndata object not being garbage collected. self.attr_values[attr] = getattr(adata, attr) self.obs_columns_to_validate = obs_columns_to_validate
[docs] def validate_anndata(self, adata: AnnData) -> None: """Validate anndata has proper attributes.""" for attr in self.attrs: value = getattr(adata, attr) ref_value = self.attr_values[attr] if attr == "obs": if self.obs_columns_to_validate is not None: # Subset the columns to validate ref_value = ref_value[self.obs_columns_to_validate] value = value[self.obs_columns_to_validate] # compare the elements inside the Index object and their order if not ref_value.columns.equals(value.columns): raise ValueError( ".obs attribute columns for anndata passed in does not match .obs attribute columns " "of the reference anndata." ) if not ref_value.dtypes.equals(value.dtypes): for col in ref_value.columns: if ref_value[col].dtype != value[col].dtype: if ref_value[col].dtype == "category": diff = set(ref_value[col].cat.categories).symmetric_difference( set(value[col].cat.categories) ) raise ValueError( f".obs['{col}'].cat.categories for anndata passed in " f"do not match those in the reference anndata. symmetric_differece: {diff}" ) raise ValueError( f".obs['{col}'] dtype for anndata passed in ({value[col].dtype}) " f"does not match .obs['{col}'] dtype of the reference anndata ({ref_value[col].dtype})" ) raise ValueError( f".obs columns for anndata passed in ({value.columns}) do not match .obs columns " f"of the reference anndata ({ref_value.columns})." ) elif attr in ["var", "var_names"]: # For var compare if two DataFrames have the same shape and elements # and the same row/column index. # For var_names compare the elements inside the Index object and their order if not ref_value.equals(value): raise ValueError( f".{attr} attribute for anndata passed in does not match .{attr} attribute " "of the reference anndata." ) elif attr in ["layers", "obsm", "varm", "varp"]: # compare the keys if not set(ref_value.keys()) == set(value.keys()): raise ValueError( f".{attr} attribute keys for anndata passed in does not match .{attr} attribute keys " "of the reference anndata." ) if attr in ["varm", "varp"]: for key in ref_value: arr = value[key] ref_arr = ref_value[key] if not np.array_equal(ref_arr, arr): raise ValueError( f".{attr} attribute for anndata passed in does not match .{attr} attribute " "of the reference anndata." )