Source code for cellarium.ml.models.socam

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from typing import TypedDict

import lightning.pytorch as pl
import numpy as np
import torch
import torch.nn.functional

from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)


class NonleafInfo(TypedDict):
    nonleaf_desc_cc: torch.Tensor
    perm: torch.Tensor
    inv_perm: torch.Tensor


def _expand_with_ancestors(
    cl_name_subset: list[str],
    cl_names: list[str],
    descendant_tensor: torch.Tensor,
) -> list[str]:
    """Return ``cl_name_subset`` extended to include every ancestor of each node.

    ``descendant_tensor[k, j] == 1`` means j is a descendant of k, equivalently k
    is an ancestor of j.  So the ancestors of j are all k where column j is 1.

    Args:
        cl_name_subset: Category names requested by the user.
        cl_names: Full ordered list of category names (rows/cols of ``descendant_tensor``).
        descendant_tensor: Binary ``(C, C)`` tensor on any device.

    Returns:
        Sorted list containing every name in ``cl_name_subset`` plus all of their
        ancestors that are present in ``cl_names``.
    """
    index_map = {cat: i for i, cat in enumerate(cl_names)}
    expanded: set[int] = set()
    for name in cl_name_subset:
        j = index_map[name]
        # column j: entry k==1 means k is an ancestor of j (including j itself)
        ancestor_indices = (descendant_tensor[:, j] > 0).nonzero(as_tuple=True)[0].tolist()
        expanded.update(ancestor_indices)
    return sorted(cl_names[i] for i in expanded)


def _build_nonleaf_info(desc_matrix_cc: torch.Tensor) -> NonleafInfo:
    """Precompute the non-leaf descriptor tensors needed by ``propagate_logits``.

    Produces a permutation that places non-leaf categories first and leaves last,
    so the hot path can use ``torch.cat`` and simple slices instead of boolean-mask
    assignment.  Boolean-mask assignment internally calls ``aten.nonzero`` which
    ``torch.compile`` / inductor cannot fuse; integer-index gather and ``torch.cat``
    have no such restriction.

    Runs once per unique ``desc_matrix_cc`` (construction time, not the hot path).

    Args:
        desc_matrix_cc: Binary ``(c, c)`` descendant tensor.

    Returns:
        A :class:`NonleafInfo` dict with keys ``nonleaf_desc_cc``, ``perm``, and
        ``inv_perm``.  ``nonleaf_desc_cc`` is ``(c_nonleaf, c)`` with columns in
        ``perm`` order; ``perm`` places non-leaf indices first, leaf last;
        ``inv_perm`` restores the original column order.
    """
    nonleaf_mask = desc_matrix_cc.sum(dim=1) > 1
    nonleaf_indices = nonleaf_mask.nonzero(as_tuple=True)[0]  # (c_nonleaf,)
    leaf_indices = (~nonleaf_mask).nonzero(as_tuple=True)[0]  # (c_leaf,)
    perm = torch.cat([nonleaf_indices, leaf_indices])  # non-leaf first
    inv_perm = torch.argsort(perm)
    # Reorder desc columns to match the permuted input order used in propagate_logits
    nonleaf_desc_cc = desc_matrix_cc[nonleaf_indices][:, perm]  # (c_nonleaf, c)
    return {"nonleaf_desc_cc": nonleaf_desc_cc, "perm": perm, "inv_perm": inv_perm}


@torch.compile()
def propagate_probs(probs_nc: torch.Tensor, descendant_tensor_cc: torch.Tensor) -> torch.Tensor:
    """
    Propagate probabilities up the hierarchy defined by ``descendant_tensor_cc`` using matrix multiplication.
    This effectively sums the probabilities of all descendant categories for each category.
    The output is then clamped to a maximum of 1.0 to ensure valid probability values.

    Args:
        probs_nc: Tensor of shape (n, c) containing the probabilities for each category.
        descendant_tensor_cc: Binary tensor of shape (c, c) defining descendant relationships.

    Returns:
        Tensor of shape (n, c) containing the propagated probabilities for each category
    """
    propagated_probs_nc = torch.einsum(
        "nc,kc->nk",
        probs_nc,
        descendant_tensor_cc,
    )
    return torch.clamp(propagated_probs_nc, max=1.0)


def _logsumexp_propagated(logits_nc: torch.Tensor, desc_matrix_cc: torch.Tensor) -> torch.Tensor:
    temp = torch.where(desc_matrix_cc.T == 0, float("-inf"), logits_nc.unsqueeze(dim=-1) * desc_matrix_cc.T)
    return temp.logsumexp(dim=1)


@torch.compile()
def propagate_logits(
    logits_nc: torch.Tensor,
    nonleaf_desc_cc: torch.Tensor,
    perm: torch.Tensor,
    inv_perm: torch.Tensor,
) -> torch.Tensor:
    """
    Perform probability propagation in logit space.

    Non-leaf output categories reduce over all their descendants via
    ``_logsumexp_propagated`` using a ``(c_nonleaf, c)`` submatrix, so the
    intermediate tensor is ``(n, c, c_nonleaf)`` rather than ``(n, c, c)``.
    Leaf output categories are the identity (logsumexp of a single element).

    ``perm`` / ``inv_perm`` sort columns so non-leaf outputs come first,
    allowing assembly via ``torch.cat`` and a single integer-index gather —
    avoiding ``aten.nonzero`` which breaks ``torch.compile``.

    Args:
        logits_nc: ``(n, c)`` raw logit tensor.
        nonleaf_desc_cc: ``(c_nonleaf, c)`` descendant rows for non-leaf outputs,
            with columns in ``perm`` order (from ``_build_nonleaf_info``).
        perm: ``(c,)`` permutation — non-leaf indices first, leaf last
            (from ``_build_nonleaf_info``).
        inv_perm: ``(c,)`` inverse of ``perm`` (from ``_build_nonleaf_info``).
    Returns:
        ``(n, c)`` propagated log-probability tensor in original column order.
    """
    c_nonleaf = nonleaf_desc_cc.shape[0]
    logits_reordered = logits_nc[:, perm]  # (n, c): non-leaf first
    nonleaf_part = _logsumexp_propagated(logits_reordered, nonleaf_desc_cc)  # (n, c_nonleaf)
    leaf_part = logits_reordered[:, c_nonleaf:]  # (n, c_leaf)
    out = torch.cat([nonleaf_part, leaf_part], dim=1)[:, inv_perm]  # (n, c) original order
    return out - torch.logsumexp(logits_nc, dim=1, keepdim=True)


[docs] class SOCAM(CellariumModel, PredictMixin, ValidateMixin): """ Logistic regression model for cell type ontology classification. Args: n_obs: Number of observations in the dataset (used to scale the cross-entropy loss). var_names_g: The variable-name schema for the input data; used for validation. output_categories: Total number of target categories expected at prediction/validation time. Used when the trained model has fewer categories than the final output space. descendant_tensor: Binary (0/1) tensor of shape (n_categories, n_categories) defining the descendant relationships between categories. Row i contains ones for all categories considered descendants of category i (plus self). Used for probability-propagation. cl_names: Full list of category identifiers matching the rows/columns of ``descendant_tensor``. cl_name_subset: Optional list of category names (from ``cl_names``) to restrict training and prediction to. The list is sorted internally so order does not matter. When ``None``, all categories are used. probability_propagation_flag: If True, applies hierarchical probability propagation before predicting the output distribution. W_prior_scale: Scale (b) parameter of the Laplace prior on the weight matrix `W_gc`. W_init_scale: Standard deviation for initializing `W_gc`. seed: Random seed used to initialize parameters. log_metrics: If True, logs weight histograms (TensorBoard) during training. If True, logs weight histograms (TensorBoard) during training. """ def __init__( self, n_obs: int, var_names_g: np.ndarray, descendant_tensor: torch.Tensor, cl_names: list[str], cl_name_subset: list[str] | None = None, probability_propagation_flag: bool = True, W_prior_scale: float = 1e-2, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True, include_ancestors_of_cl_name_subset: bool = True, ) -> None: super().__init__() self.n_obs = n_obs self.var_names_g = var_names_g self.n_vars = len(var_names_g) self.cl_names = cl_names descendant_tensor = descendant_tensor.float() if descendant_tensor.shape[0] != descendant_tensor.shape[1]: raise ValueError("`descendant_tensor` should be a square matrix.") if descendant_tensor.trace() != descendant_tensor.shape[0]: raise ValueError( "`descendant_tensor` should have ones on the diagonal (each category is a descendant of itself)." ) if len(cl_names) != descendant_tensor.shape[0]: raise ValueError("Length of `cl_names` should match the number of rows in `descendant_tensor`.") self._descendant_tensor = descendant_tensor self.register_buffer("descendant_tensor", descendant_tensor) self.n_categories = descendant_tensor.shape[0] self.include_ancestors_of_cl_name_subset = include_ancestors_of_cl_name_subset if include_ancestors_of_cl_name_subset and cl_name_subset is not None: cl_name_subset = _expand_with_ancestors(cl_name_subset, cl_names, descendant_tensor) self.cl_name_subset = cl_name_subset self.probability_propagation_flag = probability_propagation_flag self.seed = seed self.log_metrics = log_metrics # Build active category set (subset if given, else full list) active_cl_names: list[str] = cl_name_subset if cl_name_subset is not None else list(cl_names) self.active_cl_names = active_cl_names self.n_active_cats = len(active_cl_names) self.label_lookup: dict[str, int] = {name: i for i, name in enumerate(active_cl_names)} # Precompute the active-category submatrix of the descendant tensor. # Use the raw (pre-registration) descendant_tensor so this works even when # __init__ is called inside a meta-device context (register_buffer would # create a meta tensor, but descendant_tensor here is still a real CPU tensor). index_map_init = {cat: i for i, cat in enumerate(cl_names)} active_indices = [index_map_init[cat] for cat in active_cl_names] ix = torch.tensor(active_indices, dtype=torch.long) active_descendant_tensor_cc = descendant_tensor[ix][:, ix] nonleaf_info = _build_nonleaf_info(active_descendant_tensor_cc) # Store CPU copies so reset_parameters() can repopulate after meta→real copy. self._active_descendant_tensor_cc = active_descendant_tensor_cc self._nonleaf_desc_cc = nonleaf_info["nonleaf_desc_cc"] self._perm = nonleaf_info["perm"] self._inv_perm = nonleaf_info["inv_perm"] self.register_buffer("active_descendant_tensor_cc", active_descendant_tensor_cc) self.register_buffer("nonleaf_desc_cc", nonleaf_info["nonleaf_desc_cc"]) self.register_buffer("perm", nonleaf_info["perm"]) self.register_buffer("inv_perm", nonleaf_info["inv_perm"]) # Trainable parameters — sized to active categories only self._W_prior_scale = W_prior_scale self.W_init_scale = W_init_scale self.W_prior_scale: torch.Tensor self.register_buffer("W_prior_scale", torch.empty(())) self.W_gc = torch.nn.Parameter(torch.empty(self.n_vars, self.n_active_cats, dtype=torch.float)) self.b_c = torch.nn.Parameter(torch.empty(self.n_active_cats, dtype=torch.float)) self.reset_parameters() def reset_parameters(self) -> None: rng_device = self.W_gc.device.type if self.W_gc.device.type != "meta" else "cpu" rng = torch.Generator(device=rng_device) rng.manual_seed(self.seed) self.W_prior_scale.fill_(self._W_prior_scale) self.descendant_tensor.copy_(self._descendant_tensor) self.active_descendant_tensor_cc.copy_(self._active_descendant_tensor_cc) self.nonleaf_desc_cc.copy_(self._nonleaf_desc_cc) self.perm.copy_(self._perm) self.inv_perm.copy_(self._inv_perm) self.W_gc.data.normal_(0, self.W_init_scale, generator=rng) self.b_c.data.zero_() def _cl_names_to_indices(self, cl_names_n: np.ndarray) -> torch.Tensor: """ Convert a per-cell array of string category names to a 1-D integer tensor of 0-based indices into the active category list. Args: cl_names_n: Array of length n containing category name strings for each cell. Returns: Long tensor of shape ``(n,)`` with integer category indices. Raises: ValueError: If any label in ``cl_names_n`` is not present in ``self.label_lookup``. """ try: return torch.tensor([self.label_lookup[c] for c in cl_names_n], dtype=torch.long) except KeyError as exc: valid = sorted(self.label_lookup.keys()) raise ValueError(f"Label {exc} is not in the active category list. Valid labels are: {valid}") from exc
[docs] def forward( self, x_ng: torch.Tensor, var_names_g: np.ndarray, cl_names_n: np.ndarray, ) -> dict[str, torch.Tensor | None]: """ Args: x_ng: The input data. var_names_g: The variable names for the input data. cl_names_n: Array of length n containing a category name string (from ``self.cl_names``) for each cell. When ``self.cl_name_subset`` is set, every label must be a member of that subset. Returns: A dictionary with the loss value. """ assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g) assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g) y_n = self._cl_names_to_indices(cl_names_n).to(x_ng.device) logits_nc = self._compute_regression(x_ng, self.W_gc, self.b_c) if self.probability_propagation_flag: logits_nc = propagate_logits(logits_nc, self.nonleaf_desc_cc, self.perm, self.inv_perm) scale = self.n_obs / x_ng.shape[0] ce_loss = torch.nn.functional.cross_entropy(logits_nc, y_n, reduction="sum") * scale laplace_loss = self.W_gc.abs().sum() / self.W_prior_scale loss = ce_loss + laplace_loss return {"loss": loss}
@torch.compile() def _compute_regression(self, x_ng: torch.Tensor, W_gc: torch.Tensor, b_c: torch.Tensor) -> torch.Tensor: return x_ng @ W_gc + b_c
[docs] def validate( self, trainer: pl.Trainer, pl_module: pl.LightningModule, batch_idx: int, x_ng: torch.Tensor, var_names_g: np.ndarray, cl_names_n: np.ndarray, ) -> None: """ Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior. """ output = self( x_ng=x_ng, var_names_g=var_names_g, cl_names_n=cl_names_n, ) loss = output.get("loss") if loss is not None: # Logging to TensorBoard by default pl_module.log("val_loss", loss, sync_dist=True, on_epoch=True)
[docs] def predict( self, x_ng: torch.Tensor, var_names_g: np.ndarray, ) -> dict[str, np.ndarray | torch.Tensor]: """ Predict the target logits. Args: x_ng: The input data. var_names_g: The variable names for the input data. Returns: A dictionary with the target logits. Output tensors have shape ``(n, n_active_cats)``. """ assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g) assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g) logits_nc = self._compute_regression(x_ng, self.W_gc, self.b_c) probs_nc = torch.nn.functional.softmax(logits_nc, dim=1) if self.probability_propagation_flag: probs_nc = propagate_probs(probs_nc, self.active_descendant_tensor_cc) return {"y_logits_nc": logits_nc, "cell_type_probs_nc": probs_nc}
def on_train_epoch_end(self, trainer: pl.Trainer) -> None: if trainer.global_rank != 0: return if not self.log_metrics: return for logger in trainer.loggers: if isinstance(logger, pl.loggers.TensorBoardLogger): try: logger.experiment.add_histogram( "W_gc", self.W_gc, global_step=trainer.global_step, ) except ValueError as e: warnings.warn(f"Failed to log histogram for W_gc step={trainer.global_step} due to {e}")