# Source code for cellarium.ml.transforms.filter

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import weakref
from collections.abc import Sequence
from functools import cache
from typing import Any, ClassVar

import numpy as np
import torch
from torch import nn

from cellarium.ml.utilities.testing import (
    assert_columns_and_array_lengths_equal,
)


[docs] class Filter(nn.Module): """ Filter gene counts by a list of features. When ``ordering=False``, the output columns follow the order genes appear in the input ``var_names_g``: .. math:: \\mathrm{mask}_g = \\mathrm{feature}_g \\in \\mathrm{filter\\_list} y_{ng} = x_{ng}[:, \\mathrm{mask}_g] When ``ordering=True`` (default), the output columns follow the order of :attr:`filter_list`: .. math:: y_{ng} = x_{ng}[:, \\sigma(\\mathrm{filter\\_list})] where :math:`\\sigma` maps each entry in :attr:`filter_list` to its column index in the input. Args: filter_list: A list of features to filter by. ordering: If ``True`` (default), output columns are ordered to match :attr:`filter_list`. If ``False``, output columns follow the order genes appear in the input ``var_names_g``. Use ``ordering=True`` when running inference on data with a different gene ordering than seen during training. allow_missing: If ``True``, genes in :attr:`filter_list` that are absent from the input are zero-filled in the output. Requires ``ordering=True``. If ``False`` (default), all genes in :attr:`filter_list` must be present in the input. """ def __init__(self, filter_list: Sequence[str], ordering: bool = True, allow_missing: bool = False) -> None: super().__init__() self.filter_list = np.array(filter_list) if len(self.filter_list) == 0: raise ValueError(f"`filter_list` must not be empty. Got {self.filter_list}") if allow_missing and not ordering: raise ValueError("`allow_missing=True` requires `ordering=True`.") self.ordering = ordering self.allow_missing = allow_missing
[docs] @cache def filter( self, var_names_g: tuple ) -> ( np.ndarray[Any, np.dtype[np.intp]] | tuple[np.ndarray[Any, np.dtype[np.intp]], np.ndarray[Any, np.dtype[np.intp]]] ): """ Args: var_names_g: The list of the variable names in the input data. Returns: When ``ordering=False``: a 1-D array of source indices in input order. When ``ordering=True`` and ``allow_missing=False``: a 1-D array of source indices ordered to match :attr:`filter_list`. When ``ordering=True`` and ``allow_missing=True``: a tuple ``(src_indices, out_indices)`` where ``src_indices`` indexes columns in ``var_names_g`` and ``out_indices`` gives the corresponding destination column in the output (which always has ``len(filter_list)`` columns). """ if not self.ordering: mask = np.isin(var_names_g, self.filter_list) if not np.any(mask): raise AssertionError("No features in `var_names_g` matched the `filter_list`") return np.where(mask)[0] # ordering=True: iterate over filter_list to enforce its order in the output. var_names_index = {name: idx for idx, name in enumerate(var_names_g)} src_indices: list[int] = [] out_indices: list[int] = [] missing: list[str] = [] for out_idx, gene in enumerate(self.filter_list): if gene in var_names_index: src_indices.append(var_names_index[gene]) out_indices.append(out_idx) else: missing.append(gene) if not src_indices: raise AssertionError("No features in `var_names_g` matched the `filter_list`") if missing and not self.allow_missing: raise AssertionError(f"The following features in `filter_list` were not found in `var_names_g`: {missing}") src_indices_ordered = np.array(src_indices, dtype=np.intp) if self.allow_missing: return src_indices_ordered, np.array(out_indices, dtype=np.intp) return src_indices_ordered
[docs] def forward(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, torch.Tensor | np.ndarray]: """ .. note:: When used with :class:`~cellarium.ml.core.CellariumModule` or :class:`~cellarium.ml.core.CellariumPipeline`, ``x_ng`` and ``var_names_g`` keys in the input dictionary will be overwritten with the filtered values. Args: x_ng: Gene counts. var_names_g: The list of the variable names in the input data. Returns: A dictionary with the following keys: - ``x_ng``: Gene counts filtered (and reordered if ``ordering=True``) to match :attr:`filter_list`. Shape ``(n, len(filter_list))`` when ``ordering=True``, otherwise ``(n, num_matched)``. - ``var_names_g``: Gene names corresponding to the output columns. """ assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g) result = self.filter(tuple(var_names_g.tolist())) if self.allow_missing: src_indices, out_indices = result assert isinstance(src_indices, np.ndarray) and isinstance(out_indices, np.ndarray) x_out = torch.zeros(x_ng.shape[0], len(self.filter_list), dtype=x_ng.dtype, device=x_ng.device) x_out[:, out_indices] = x_ng[:, src_indices] x_ng = x_out var_names_g = self.filter_list.copy() elif self.ordering: assert isinstance(result, np.ndarray) x_ng = x_ng[:, result] var_names_g = self.filter_list.copy() else: assert isinstance(result, np.ndarray) ndx = torch.arange(x_ng.shape[0]) x_ng = x_ng[ndx[:, None], result] var_names_g = var_names_g[result] return {"x_ng": x_ng, "var_names_g": var_names_g}
def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" f"filter_list={self.filter_list}, " f"ordering={self.ordering}, " f"allow_missing={self.allow_missing})" )