# Source code for cellarium.ml.models.scvi

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

"""Flexible modified version of single-cell variational inference (scVI) re-implemented in Cellarium ML."""

import importlib
import itertools
import logging
from abc import abstractmethod
from typing import Any, Literal, Sequence

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
import zuko.flows
from anndata import AnnData
from torch.distributions import Distribution, Normal, Poisson
from torch.distributions import kl_divergence as kl

from cellarium.ml.distributions import NegativeBinomial
from cellarium.ml.layers import DressedLayer, FullyConnectedLinear
from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.utilities.data import categories_to_product_codes
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)

logger = logging.getLogger(__name__)


def class_from_class_path(class_path: str):
    """Resolve a dotted path such as ``"torch.nn.Linear"`` to the object it names.

    Everything before the last dot is imported as a module; the part after the last dot is
    looked up as an attribute of that module.

    Args:
        class_path: fully qualified dotted path, e.g. ``"package.module.ClassName"``.

    Returns:
        The attribute (typically a class) named by ``class_path``.

    Raises:
        ValueError: if ``class_path`` contains no dot.
        ModuleNotFoundError: if the module part cannot be imported.
        AttributeError: if the module has no attribute with that name.
    """
    module_path, attribute_name = class_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), attribute_name)


def instantiate_from_class_path(class_path: str, *args, **kwargs):
    """Import the class named by ``class_path`` and construct an instance of it.

    Args:
        class_path: dotted path accepted by ``class_from_class_path``.
        *args: positional arguments forwarded to the constructor.
        **kwargs: keyword arguments forwarded to the constructor.

    Returns:
        A new instance of the resolved class.
    """
    target_cls = class_from_class_path(class_path)
    instance = target_cls(*args, **kwargs)
    return instance


def weights_init(m):
    """Initialize the parameters of a single module in place.

    Can be passed to ``torch.nn.Module.apply``; modules of any other type are left untouched.

    * ``torch.nn.BatchNorm1d``: weight ~ N(1.0, 0.02), bias set to zero (when the layer is affine).
    * ``torch.nn.Linear`` and its subclasses (e.g. ``LinearWithStructuredBias``): Xavier-normal
      weight, bias (if present) set to zero.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, torch.nn.BatchNorm1d):
        # BatchNorm1d(affine=False) registers weight/bias as None; there is nothing to initialize
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.Linear):
        # LinearWithStructuredBias subclasses torch.nn.Linear, so it is covered by this branch
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


class LinearWithStructuredBias(torch.nn.Linear):
    """A `torch.nn.Linear` layer where batch indices are given as input to the forward pass.

    Args:
        in_features: passed to `torch.nn.Linear`
        out_features: passed to `torch.nn.Linear`
        n_batch: the dimensionality of the batch representation
        categorical_covariate_dimensions: a list of integers containing the number of categories
            for each categorical covariate
        label_to_bias_hidden_layers: a list of hidden layer sizes for the label-to-bias decoder
        bias: passed to `torch.nn.Linear` (True is like the scvi-tools implementation)
        label_to_bias_dressing_init_kwargs: a dictionary of keyword arguments to pass to
            the `DressedLayer` constructor
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        label_to_bias_hidden_layers: list[int],
        n_batch: int = 0,
        categorical_covariate_dimensions: list[int] | None = None,
        bias: bool = True,
        label_to_bias_dressing_init_kwargs: dict[str, Any] | None = None,
    ):
        super().__init__(in_features, out_features, bias=bias)
        if categorical_covariate_dimensions is None:
            categorical_covariate_dimensions = []
        if label_to_bias_dressing_init_kwargs is None:
            label_to_bias_dressing_init_kwargs = {}
        if n_batch + sum(categorical_covariate_dimensions) < 1:
            raise ValueError("in_features=0: at least one batch or categorical covariate dimension must be provided.")
        self.bias_decoder = FullyConnectedLinear(
            in_features=n_batch + sum(categorical_covariate_dimensions),
            out_features=out_features,
            n_hidden=label_to_bias_hidden_layers,
            dressing_init_kwargs=label_to_bias_dressing_init_kwargs,
        )

    @abstractmethod
    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Returns the bias given batch representations.

        Args:
            batch_nb: a tensor of batch representations (could be one-hot) of shape (n, batch_latent_dim)
            categorical_covariate_np: a tensor of categorical covariates of shape (n, sum(n_categories_per_covariate))

        Returns:
            a tensor of shape (n, out_features)
        """
        pass

    def forward(  # type: ignore[override]
        self,
        x_ng: torch.Tensor,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """
        Computes the forward pass of the layer as
        out = x @ self.weight.T + self.bias + bias

        where bias is computed as
        bias = bias_encoder(batch)

        Args:
            x_ng: a tensor of shape (n, in_features)
            batch_nb: a tensor of batch indices of shape (n, batch_latent_dim)
            categorical_covariate_np: a tensor of categorical covariates of shape (n, sum(n_categories_per_covariate))
        """
        return super().forward(x_ng) + self.compute_bias(
            batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np
        )


class LinearWithBatchAndCovariates(LinearWithStructuredBias):
    """Structured-bias linear layer whose bias depends on both the batch representation and the
    categorical covariates, concatenated along the last dimension."""

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return ``bias_decoder(cat([batch_nb, categorical_covariate_np], dim=-1))``.

        Raises:
            ValueError: if ``categorical_covariate_np`` is None.
        """
        if categorical_covariate_np is None:
            raise ValueError("Categorical covariates must be provided to LinearWithBatchAndCovariates")
        decoder_input = torch.cat([batch_nb, categorical_covariate_np], dim=-1)
        return self.bias_decoder(decoder_input)


class LinearWithBatch(LinearWithStructuredBias):
    """Structured-bias linear layer whose bias depends only on the batch representation."""

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return ``bias_decoder(batch_nb)``.

        ``categorical_covariate_np`` is accepted only for interface compatibility with the other
        structured-bias layers and is ignored.
        """
        structured_bias = self.bias_decoder(batch_nb)
        return structured_bias


class LinearWithCovariates(LinearWithStructuredBias):
    """Structured-bias linear layer whose bias depends only on the categorical covariates.

    The batch representation is accepted by :meth:`compute_bias` but not used.
    """

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return ``bias_decoder(categorical_covariate_np)``.

        Raises:
            ValueError: if ``categorical_covariate_np`` is None.
        """
        if categorical_covariate_np is not None:
            return self.bias_decoder(categorical_covariate_np)
        raise ValueError("Categorical covariates must be provided to LinearWithCovariates")


class FullyConnectedWithBatchArchitecture(torch.nn.Module):
    """
    Fully connected block of layers (can be empty) that can include LinearWithBatch layers.
    The forward pass takes per-cell batches.

    Args:
        in_features: The dimensionality of the input
        layers: A list of dictionaries, each containing the following keys:
            * ``class_path``: the class path of the layer to use
            * ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
                - must contain "out_features"
    """

    def __init__(
        self,
        in_features: int,
        layers: list[dict],
    ):
        super().__init__()
        for layer in layers:
            assert "out_features" in layer["init_args"], """
            "out_features" must be specified in init_args for hidden layers, e.g.

            - class_path: cellarium.ml.models.scvi.LinearWithBatch
              init_args:
                out_features: 128
            """

        if len(layers) == 0:
            module_list = torch.nn.ModuleList([torch.nn.Identity()])
            out_features = in_features
        else:
            module_list = torch.nn.ModuleList()
            n_hidden = [layer["init_args"].get("out_features") for layer in layers]
            for layer, n_in, n_out in zip(layers, [in_features] + n_hidden, n_hidden):
                layer["init_args"]["out_features"] = n_out
                module_list.append(
                    DressedLayer(
                        instantiate_from_class_path(
                            layer["class_path"],
                            in_features=n_in,
                            bias=True,
                            **layer["init_args"],
                        ),
                        **layer["dressing_init_args"],
                    )
                )
            assert hasattr(module_list[-1].layer, "out_features") and isinstance(
                module_list[-1].layer.out_features, int
            )
            out_features = module_list[-1].layer.out_features
        self.module_list = module_list
        self.out_features = out_features

    def forward(
        self,
        x_ng: torch.Tensor,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None,
    ) -> torch.Tensor:
        x_ = x_ng
        for dressed_layer in self.module_list:
            x_ = (
                dressed_layer(x_, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
                if (hasattr(dressed_layer, "layer") and isinstance(dressed_layer.layer, LinearWithStructuredBias))
                else dressed_layer(x_)
            )
        return x_


class EncoderSCVI(torch.nn.Module):
    """
    Encode data of ``in_features`` dimensions into a latent space of ``out_features`` dimensions.

    The input passes through a (possibly empty) stack of hidden layers; two parallel final layers
    then produce the mean and the unconstrained variance of a diagonal Normal distribution.

    Args:
        in_features: The dimensionality of the input (data space)
        out_features: The dimensionality of the output (latent space)
        hidden_layers: A list of dictionaries, each containing the following keys:
            * ``class_path``: the class path of the layer to use
            * ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
                - must contain "out_features"
        final_layer: Same as hidden_layers, but for the final (mean and variance) layers.
            Its ``init_args`` must NOT contain "out_features" (it is set to this module's
            ``out_features``) and may contain "bias" (defaults to True).
        var_eps: Minimum value for the variance; used for numerical stability
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_layers: list[dict],
        final_layer: dict,
        var_eps: float = 1e-4,
    ):
        super().__init__()
        self.fully_connected = FullyConnectedWithBatchArchitecture(in_features, hidden_layers)
        final_init_args = final_layer["init_args"]
        final_bias = final_init_args.get("bias", True)
        # "bias" is passed explicitly; drop it from the splatted kwargs so a user-specified
        # bias does not raise "got multiple values for keyword argument 'bias'"
        # (matches how DecoderSCVI builds its final layer)
        final_kwargs = {k: v for k, v in final_init_args.items() if k != "bias"}
        self.mean_encoder = instantiate_from_class_path(
            final_layer["class_path"],
            in_features=self.fully_connected.out_features,
            out_features=out_features,
            bias=final_bias,
            **final_kwargs,
        )
        self.var_encoder = instantiate_from_class_path(
            final_layer["class_path"],
            in_features=self.fully_connected.out_features,
            out_features=out_features,
            bias=final_bias,
            **final_kwargs,
        )
        # both final layers share the same class, so one flag decides how both are called
        self.mean_encoder_injects_covariates = isinstance(self.mean_encoder, LinearWithStructuredBias)
        self.var_eps = var_eps

    def forward(
        self,
        x_ng: torch.Tensor,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None,
    ) -> Distribution:
        """Return the approximate posterior q(z | x).

        Args:
            x_ng: input data of shape (n, in_features)
            batch_nb: per-cell batch representations
            categorical_covariate_np: per-cell categorical covariates, or None

        Returns:
            ``Normal(mean, sqrt(exp(var_encoder_output) + var_eps))`` over the latent space.
        """
        q_nh = self.fully_connected(x_ng, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
        q_mean_nk = (
            self.mean_encoder(q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
            if self.mean_encoder_injects_covariates
            else self.mean_encoder(q_nh)
        )
        # exp keeps the variance positive; var_eps bounds it away from zero
        q_var_nk = (
            torch.exp(
                self.var_encoder(q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
                if self.mean_encoder_injects_covariates
                else self.var_encoder(q_nh)
            )
            + self.var_eps
        )
        return Normal(q_mean_nk, q_var_nk.sqrt())


class DecoderSCVI(torch.nn.Module):
    """
    Decode data of ``in_features`` latent dimensions into data space of ``out_features`` dimensions.

    Args:
        in_features: The dimensionality of the input (latent space)
        out_features: The dimensionality of the output (data space)
        hidden_layers: A list of dictionaries, each containing the following keys:
            * ``class_path``: the class path of the layer to use
            * ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
                - must contain "out_features"
        final_layer: Same as hidden_layers, but for the final layer
        dispersion: Granularity at which the overdispersion of the negative binomial distribution is computed
        gene_likelihood: Distribution to use for reconstruction in the generative process
        scale_activation: Activation layer to use to compute normalized counts (before multiplying by library size)
        final_additive_bias: If True, the final layer will have a batch-specific bias added after the activation.
            If final_layer is a LinearWithBatch layer and final_additive_bias is True, the last layer of the decoder
            will act as a batch-specific affine transformation.
        eps: Numerical stability factor added to mean and inverse overdispersion of negative binomial
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_layers: list[dict],
        final_layer: dict,
        n_batch: int,
        dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
        gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
        scale_activation: Literal["softmax", "softplus"] = "softmax",
        final_additive_bias: bool = False,
        n_cats_per_cov: list[int] | None = None,
        eps: float = 1e-10,
    ):
        super().__init__()
        self.eps = eps
        self.n_batch = n_batch
        self.n_cats_per_cov = n_cats_per_cov
        if gene_likelihood == "zinb":
            raise NotImplementedError("Zero-inflated negative binomial not yet implemented")
        self.gene_likelihood = gene_likelihood
        self.fully_connected = FullyConnectedWithBatchArchitecture(in_features, hidden_layers)
        self.inverse_overdispersion_decoder = (
            torch.nn.Linear(self.fully_connected.out_features, out_features)
            if ((gene_likelihood != "poisson") and (dispersion == "gene-cell"))
            else None
        )
        self.dropout_decoder = (
            torch.nn.Linear(self.fully_connected.out_features, out_features) if (gene_likelihood == "zinb") else None
        )
        final_layer_init_args = final_layer["init_args"]
        self.normalized_count_decoder = instantiate_from_class_path(
            final_layer["class_path"],
            in_features=self.fully_connected.out_features,
            out_features=out_features,
            bias=final_layer_init_args.get("bias", True),
            **{k: v for k, v in final_layer_init_args.items() if k != "bias"},
        )
        self.count_decoder_takes_batch = isinstance(self.normalized_count_decoder, LinearWithStructuredBias)
        self.normalized_count_activation = (
            torch.nn.Softmax(dim=-1) if (scale_activation == "softmax") else torch.nn.Softplus()
        )
        self.final_additive_bias = final_additive_bias
        if self.n_cats_per_cov is None:
            categorical_features = 0
        else:
            categorical_features = sum(self.n_cats_per_cov)
        self.final_additive_bias_layer: torch.nn.Sequential | None = None
        if self.final_additive_bias:
            self.final_additive_bias_layer = torch.nn.Sequential(
                FullyConnectedLinear(
                    in_features=self.n_batch + categorical_features,
                    out_features=out_features,
                    n_hidden=[],
                    dressing_init_kwargs={},
                ),
                torch.nn.ReLU(),
            )

    def forward(
        self,
        z_nk: torch.Tensor,
        batch_nb: torch.Tensor,
        inverse_overdispersion: torch.Tensor | None,
        library_size_n1: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> Distribution:
        # bulk of the network
        q_nh = self.fully_connected(
            z_nk,
            batch_nb=batch_nb,
            categorical_covariate_np=categorical_covariate_np,
        )

        # mean counts
        unnormalized_chi_ng = (
            self.normalized_count_decoder(q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
            if self.count_decoder_takes_batch
            else self.normalized_count_decoder(q_nh)
        )
        chi_ng = self.normalized_count_activation(unnormalized_chi_ng)
        if self.final_additive_bias_layer is not None:
            count_mean_ng = torch.exp(library_size_n1) * chi_ng + self.final_additive_bias_layer(
                torch.cat([batch_nb, categorical_covariate_np], dim=-1)
                if categorical_covariate_np is not None
                else batch_nb
            )
        else:
            count_mean_ng = torch.exp(library_size_n1) * chi_ng

        # construct the count distribution
        dist: Distribution
        match self.gene_likelihood:
            case "nb":
                if inverse_overdispersion is None:
                    assert self.inverse_overdispersion_decoder is not None, (
                        "inverse_overdispersion must be provided when not using Poisson or gene-cell dispersion"
                    )
                    inverse_overdispersion = self.inverse_overdispersion_decoder(q_nh).exp()
                dist = NegativeBinomial(count_mean_ng + self.eps, inverse_overdispersion + self.eps)
            case "poisson":
                dist = Poisson(count_mean_ng + self.eps)
            case "zinb":
                raise NotImplementedError("ZINB is not currently implemented")
                # dist = ZeroInflatedNegativeBinomial(count_mean_ng, inverse_overdispersion, self.dropout_decoder(q_nh))

        return dist


def compute_annealed_kl_weight(
    epoch: int,
    step: int,
    n_epochs_kl_warmup: int | None,
    n_steps_kl_warmup: int | None,
    max_kl_weight: float = 1.0,
    min_kl_weight: float = 0.0,
) -> float:
    """Computes the kl weight for the current step or epoch.
    If both `n_epochs_kl_warmup` and `n_steps_kl_warmup` are None `max_kl_weight` is returned.
    Args:
        epoch: Current epoch.
        step: Current step.
        n_epochs_kl_warmup: Number of epochs to warm up the KL weight.
        n_steps_kl_warmup: Number of steps to warm up the KL weight.
        max_kl_weight: Maximum KL weight.
        min_kl_weight: Minimum KL weight.
    Returns:
        The KL weight for the current step or epoch.
    """
    if min_kl_weight > max_kl_weight:
        raise ValueError(f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}.")

    slope = max_kl_weight - min_kl_weight
    if n_epochs_kl_warmup:
        if epoch < n_epochs_kl_warmup:
            return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight
    elif n_steps_kl_warmup:
        if step < n_steps_kl_warmup:
            return slope * (step / n_steps_kl_warmup) + min_kl_weight
    return max_kl_weight


[docs] class SingleCellVariationalInference(CellariumModel, PredictMixin, ValidateMixin): """ Flexible version of single-cell variational inference (scVI) [1] re-implemented in Cellarium ML. **References:** 1. `Deep generative modeling for single-cell transcriptomics (Lopez et al.) <https://www.nature.com/articles/s41592-018-0229-2>`_. Args: var_names_g: The variable names schema for the input data validation. encoder: Dict specifying the encoder configuration. decoder: Dict specifying the decoder configuration. n_latent: Dimension of the latent space. n_batch: Number of total batches in the dataset. batch_representation_sampled: True to sample latent batch from a distribution. n_continuous_cov: Number of continuous covariates. n_cats_per_cov: A list of integers containing the number of categories for each categorical covariate. dropout_rate: Dropout rate for hidden units in the encoder only. input_gene_dropout_rate: Gene dropout rate for input data that goes into the encoder during training. dispersion: Flexibility of the dispersion parameter when ``gene_likelihood`` is either ``"nb"`` or ``"zinb"``. One of the following: * ``"gene"``: parameter is constant per gene across cells. * ``"gene-batch"``: parameter is constant per gene per batch. * ``"gene-label"``: parameter is constant per gene per label. * ``"gene-cell"``: parameter is constant per gene per cell. log_variational: If ``True``, use :func:`~torch.log1p` on input data before encoding for numerical stability (not normalization). gene_likelihood: Distribution to use for reconstruction in the generative process. One of the following: * ``"nb"``: :class:`~scvi.distributions.NegativeBinomial`. * ``"zinb"``: :class:`~scvi.distributions.ZeroInflatedNegativeBinomial`. (not implemented) * ``"poisson"``: :class:`~scvi.distributions.Poisson`. latent_distribution: Distribution to use for the latent space. One of the following: * ``"normal"``: isotropic normal. * ``"ln"``: logistic normal with normal params N(0, 1). 
(not implemented) use_batch_norm: Specifies where to use :class:`~torch.nn.BatchNorm1d` in the model. One of the following: * ``"none"``: don't use batch norm in either encoder(s) or decoder. * ``"encoder"``: use batch norm only in the encoder(s). * ``"decoder"``: use batch norm only in the decoder. * ``"both"``: use batch norm in both encoder(s) and decoder. use_layer_norm: Specifies where to use :class:`~torch.nn.LayerNorm` in the model. One of the following: * ``"none"``: don't use layer norm in either encoder(s) or decoder. * ``"encoder"``: use layer norm only in the encoder(s). * ``"decoder"``: use layer norm only in the decoder. * ``"both"``: use layer norm in both encoder(s) and decoder. Note: only one of use_batch_norm or use_layer_norm should be specified. use_size_factor_key: If ``True``, use the :attr:`~anndata.AnnData.obs` column as defined by the ``size_factor_key`` parameter in the model's ``setup_anndata`` method as the scaling factor in the mean of the conditional distribution. If ``False``, the observed library size (log of the sum of counts per cell) is used. Should be ``False``. reconstruct_counts_on_predict: Changes the behavior of :meth:`predict`. True will reconstruct gene expression count data, False will return the latent representations reconstruction_var_names_g: List of var_names to be reconstructed (outputs are dense matrices) reconstruction_transform_batch: None will reconstruct in the original data batch. This is like imputation or smoothing. An integer will reconstruct counts in the specified batch index. The string "mean" will reconstruct counts in the first 10 batches and return the mean. reconstruction_n_latent_samples: Number of latent samples to use for reconstruction. Each latent sample will be used to compute the mean of the generative distribution, and the final output will be the mean of those. reconstruction_use_latent_mean: True to use the mean of the latent distribution rather than sampling. 
reconstruction_use_importance_sampling: True to use importance sampling weighted by each latent sample's likelihood. reconstructed_library_size: The library size to use for the reconstruction, common to all cells. use_flow: If True, use a Neural Spline Flow (NSF) as the prior on the latent space instead of the standard normal N(0, I). The flow is unconditional (batch-blind) and is jointly trained with the encoder/decoder via an MC-KL estimate: E_q[log q(z|x) - log p_flow(z)]. flow_hidden_features: Hidden layer widths for the NSF. Only used when ``use_flow=True``. cell_type_categories: Ordered list of CL ID strings (e.g. ``["CL:0000540", ...]``) that matches ``adata.obs[cell_type_col].cat.categories`` exactly (same order), so that integer codes from ``.cat.codes`` map directly to rows of the internal distance buffer. Required if ``ontology_distance_matrix`` is provided. ontology_distance_matrix: Square :class:`pandas.DataFrame` with CL ID strings as both index and columns, as returned by :func:`~cellarium.ml.utilities.data.compute_cl_distance_matrix`. The constructor will slice and reorder this to match ``cell_type_categories``. Enables the frequency-weighted Spearman correlation metric (``val_ontology_spearman``) during validation. Not saved to checkpoints. val_cell_type_classifier_reservoir_size: Maximum number of cells to retain per split (train / test) for the logistic regression cell type classifier. Reservoir sampling is used so this bound is respected regardless of validation set size. Train cells are drawn from even-numbered validation batches; test cells from odd-numbered batches. Ignored when ``cell_type_categories`` is not provided. Default 50_000. 
""" def __init__( self, var_names_g: Sequence[str], encoder: dict[str, list[dict] | dict | bool], decoder: dict[str, list[dict] | dict | bool], n_batch: int = 0, n_latent: int = 10, n_continuous_cov: int = 0, n_cats_per_cov: list[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene", log_variational: bool = True, gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb", latent_distribution: Literal["normal", "ln"] = "normal", batch_embedded: bool = False, batch_representation_sampled: bool = False, n_latent_batch: int | None = None, z_kl_weight_max: float = 1.0, batch_kl_weight_max: float = 0.0, input_gene_dropout_rate: float = 0.0, use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both", use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none", kl_warmup_epochs: int | None = 400, kl_warmup_steps: int | None = None, kl_annealing_start: float = 0.0, use_size_factor_key: bool = False, reconstruct_counts_on_predict: bool = False, reconstruction_var_names_g: np.ndarray | list | None = None, reconstruction_transform_batch: None | int | str = 0, reconstruction_n_latent_samples: int = 30, reconstruction_use_latent_mean: bool = False, reconstruction_use_importance_sampling: bool = False, reconstructed_library_size: int = 10_000, reconstruction_transform_categorical_covariates: list[int] | None = None, use_flow: bool = False, flow_hidden_features: list[int] = [64, 64], cell_type_categories: list[str] | None = None, ontology_distance_matrix: pd.DataFrame | None = None, val_cell_type_classifier_reservoir_size: int = 50_000, ): super().__init__() self.var_names_g = np.array(var_names_g) self.n_input = len(self.var_names_g) self.dispersion = dispersion self.n_latent = n_latent self.n_batch = n_batch self.log_variational = log_variational self.gene_likelihood = gene_likelihood self.input_gene_dropout_rate = input_gene_dropout_rate self.latent_distribution = latent_distribution 
self.n_cats_per_cov = n_cats_per_cov if n_cats_per_cov is not None else [] self.use_size_factor_key = use_size_factor_key self.batch_embedded = batch_embedded self.batch_representation_sampled = batch_representation_sampled self.n_latent_batch = n_latent_batch if not (kl_annealing_start >= 0.0 and kl_annealing_start <= 1.0): raise ValueError(f"kl_annealing_start={kl_annealing_start} must be in the range [0.0, 1.0].") self.kl_annealing_start = kl_annealing_start assert not ((kl_warmup_steps is not None) and (kl_warmup_epochs is not None)), ( "Only one of kl_warmup_epochs or kl_warmup_steps can be specified, not both." ) self.kl_warmup_epochs = kl_warmup_epochs self.kl_warmup_steps = kl_warmup_steps assert batch_kl_weight_max >= 0.0, "batch_kl_weight must be non-negative" self.batch_kl_weight_max = batch_kl_weight_max assert z_kl_weight_max >= 0.0, "z_kl_weight must be non-negative" self.z_kl_weight_max = z_kl_weight_max self.epoch = 0 self.step = 0 self.reconstruction_var_names_g = reconstruction_var_names_g if reconstruction_var_names_g is not None: allowed_var_names_g = [] model_var_names = set(self.var_names_g) for var_name in reconstruction_var_names_g: if var_name not in model_var_names: import warnings warnings.warn( f"reconstruction_var_names_g {var_name} not found in model var_names_g. " "It will not be included in the reconstructed output. 
Check " "model.reconstruction_var_names_g for the included var_names.", UserWarning, ) else: allowed_var_names_g.append(var_name) self.reconstruction_var_names_g = allowed_var_names_g else: if reconstruct_counts_on_predict: import warnings warnings.warn( "reconstruction_var_names_g not specified, so all var_names_g will be reconstructed on predict.", UserWarning, ) self.reconstruction_var_names_g = self.var_names_g self.reconstruct_counts_on_predict = reconstruct_counts_on_predict self.reconstruction_transform_batch = reconstruction_transform_batch self.reconstruction_n_latent_samples = reconstruction_n_latent_samples self.reconstruction_use_latent_mean = reconstruction_use_latent_mean self.reconstruction_use_importance_sampling = reconstruction_use_importance_sampling self.reconstructed_library_size = reconstructed_library_size self.reconstruction_transform_categorical_covariates = reconstruction_transform_categorical_covariates self.use_flow = use_flow self.flow_hidden_features = flow_hidden_features # optional validation data metrics setup # ontology_distance_matrix without cell_type_categories: matrix is silently ignored # (can happen when the CLI links cell_type_categories from data but the batch key is absent) if cell_type_categories is None and ontology_distance_matrix is not None: logger.warning( "ontology_distance_matrix was provided but cell_type_categories is None; " "the matrix will be ignored and ontology-based metrics will be disabled." 
) ontology_distance_matrix = None self.cell_type_categories = list(cell_type_categories) if cell_type_categories is not None else None self.num_classes = len(cell_type_categories) if cell_type_categories is not None else None self.val_cell_type_classifier_reservoir_size = val_cell_type_classifier_reservoir_size if cell_type_categories is not None and ontology_distance_matrix is not None: missing = [c for c in cell_type_categories if c not in ontology_distance_matrix.index] if missing: raise ValueError( f"The following cell type categories are absent from ontology_distance_matrix: {missing}" ) sub = ontology_distance_matrix.loc[cell_type_categories, cell_type_categories] self._ontology_matrix_numpy: np.ndarray | None = sub.to_numpy(dtype=np.float32) self.register_buffer( "ontology_distance_matrix", torch.tensor(self._ontology_matrix_numpy), persistent=False ) else: self._ontology_matrix_numpy = None self.register_buffer("ontology_distance_matrix", None, persistent=False) if n_continuous_cov > 0: raise NotImplementedError("Continuous covariates are not yet implemented") if gene_likelihood == "zinb": raise NotImplementedError("Zero-inflated negative binomial not yet implemented") if latent_distribution == "ln": raise NotImplementedError("Logistic normal latent distribution is not yet implemented") # if you use one-hot and try to specify a different latent batch than n_batch, raise an error if (not self.batch_embedded) and (self.n_latent_batch is not None) and (self.n_latent_batch != self.n_batch): raise ValueError("n_latent_batch must be equal to n_batch if batch_embedded is False") # if n_latent_batch not specified, set it to n_batch if self.n_latent_batch is None: self.n_latent_batch = self.n_batch # same dim as one-hot would be # handle the embedded batch posterior # initialize the means as one-hot, std as 1 (after exp) if not self.batch_embedded: self.batch_representation_mean_bd: torch.nn.Parameter | None = None self.batch_representation_std_unconstrained_bd: 
torch.nn.Parameter | None = None else: self.batch_representation_mean_bd = torch.nn.Parameter(torch.eye(self.n_batch, self.n_latent_batch)) self.batch_representation_std_unconstrained_bd = torch.nn.Parameter( torch.zeros(self.n_batch, self.n_latent_batch) ) if self.dispersion == "gene": self.px_r = torch.nn.Parameter(torch.randn(self.n_input)) elif self.dispersion == "gene-label": raise NotImplementedError # self.px_r = torch.nn.Parameter(torch.randn(self.n_input, n_labels)) elif self.dispersion == "gene-cell": self.px_r = torch.nn.Parameter(torch.zeros(1)) # dummy else: raise ValueError( "dispersion must be one of ['gene', 'gene-label', 'gene-cell'], but input was {}".format( self.dispersion ) ) use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both" use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both" use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both" use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both" def _fill_in_layer_input_args(layer: dict): if layer["class_path"] == "cellarium.ml.models.scvi.LinearWithBatch": layer["init_args"]["n_batch"] = self.n_latent_batch layer["init_args"]["categorical_covariate_dimensions"] = [] elif layer["class_path"] == "cellarium.ml.models.scvi.LinearWithCovariates": layer["init_args"]["n_batch"] = 0 layer["init_args"]["categorical_covariate_dimensions"] = self.n_cats_per_cov elif layer["class_path"] == "cellarium.ml.models.scvi.LinearWithBatchAndCovariates": layer["init_args"]["n_batch"] = self.n_latent_batch layer["init_args"]["categorical_covariate_dimensions"] = self.n_cats_per_cov def _fill_in_layer_defaults(layer: dict, batch_norm: bool, layer_norm: bool, dropout_p: float): if "dressing_init_args" not in layer: layer["dressing_init_args"] = {} if "use_batch_norm" not in layer["dressing_init_args"]: logger.info(f"use_batch_norm not specified individually in hidden layer, setting to {batch_norm}") 
layer["dressing_init_args"]["use_batch_norm"] = batch_norm if "use_layer_norm" not in layer["dressing_init_args"]: logger.info(f"use_layer_norm not specified individually in hidden layer, setting to {layer_norm}") layer["dressing_init_args"]["use_layer_norm"] = layer_norm if "dropout_rate" not in layer["dressing_init_args"]: logger.info(f"dropout_rate not specified individually in hidden layer, setting to {dropout_p}") layer["dressing_init_args"]["dropout_rate"] = dropout_p # encoder layers assert isinstance(encoder["hidden_layers"], list), "encoder hidden_layers must be a list" for layer in encoder["hidden_layers"]: _fill_in_layer_input_args(layer) _fill_in_layer_defaults( layer, batch_norm=use_batch_norm_encoder, layer_norm=use_layer_norm_encoder, dropout_p=dropout_rate, ) assert isinstance(encoder["final_layer"], dict) _fill_in_layer_input_args(encoder["final_layer"]) # decoder layers assert isinstance(decoder["hidden_layers"], list), "decoder hidden_layers must be a list" for layer in decoder["hidden_layers"]: _fill_in_layer_input_args(layer) if "dressing_init_args" in layer and "dropout_rate" in layer["dressing_init_args"]: if layer["dressing_init_args"]["dropout_rate"] != 0.0: logger.warning( "Dropout is not supported in the decoder of scVI. " "dropout_rate is being set to 0.0 in all decoder hidden layers." 
) layer["dressing_init_args"]["dropout_rate"] = 0.0 # scvi-tools does not use dropout in the decoder _fill_in_layer_defaults( layer, batch_norm=use_batch_norm_decoder, layer_norm=use_layer_norm_decoder, dropout_p=0, ) assert isinstance(decoder["final_layer"], dict) _fill_in_layer_input_args(decoder["final_layer"]) self.z_encoder = EncoderSCVI( in_features=self.n_input, out_features=self.n_latent, hidden_layers=encoder["hidden_layers"], final_layer=encoder["final_layer"], ) assert isinstance(decoder["final_additive_bias"], bool) self.decoder = DecoderSCVI( in_features=self.n_latent, out_features=self.n_input, hidden_layers=decoder["hidden_layers"], final_layer=decoder["final_layer"], dispersion=self.dispersion, gene_likelihood=self.gene_likelihood, scale_activation="softplus" if use_size_factor_key else "softmax", final_additive_bias=decoder["final_additive_bias"], n_batch=self.n_latent_batch, # for the (optional) sizing of the final additive bias layer n_cats_per_cov=self.n_cats_per_cov, # for the (optional) sizing of the final additive bias layer ) if use_flow: # NSF.__init__ calls torch.unique(..., dim=0) which is not supported on Meta tensors. # Lightning/jsonargparse instantiates models under a torch.device("meta") context to # avoid allocating real memory during CLI parsing. Wrapping with torch.device("cpu") # forces all intermediate tensors in the NSF constructor to be concrete CPU tensors. # The Trainer will move the module to the correct device via .to() before training. 
with torch.device("cpu"): self.flow: zuko.flows.NSF | None = zuko.flows.NSF( features=self.n_latent, context=0, hidden_features=flow_hidden_features ) else: self.flow = None # detect whether the decoder injects categorical covariates at any point self._decoder_uses_categorical_covariates: bool = any( isinstance(m, (LinearWithCovariates, LinearWithBatchAndCovariates)) for m in self.decoder.modules() ) or ( self.decoder.final_additive_bias_layer is not None and len(self.n_cats_per_cov) > 0 and sum(self.n_cats_per_cov) > 0 ) self.reset_parameters() def reset_parameters(self) -> None: for m in self.modules(): m.apply(weights_init) torch.nn.init.normal_(self.px_r, mean=0.0, std=1.0) if self.batch_representation_mean_bd is not None and self.batch_representation_std_unconstrained_bd is not None: assert isinstance(self.n_latent_batch, int) # mypy with torch.no_grad(): self.batch_representation_mean_bd.data.copy_(torch.eye(self.n_batch, self.n_latent_batch)) self.batch_representation_std_unconstrained_bd.data.fill_(0.0) if self.flow is not None: # Reinitialize the NSF with fresh random weights on CPU, then copy to the target device. # This is necessary because configure_model() moves the whole model with to_empty() when # any parameter is on meta device (encoder/decoder params), which leaves NSF parameters # as uninitialized memory even though the NSF was constructed on CPU. weights_init() # doesn't cover NSF's spline parameters, so they would stay uninitialized without this. target_device = next(self.flow.parameters()).device with torch.device("cpu"): fresh_flow = zuko.flows.NSF( features=self.n_latent, context=0, hidden_features=self.flow_hidden_features ) self.flow.load_state_dict({k: v.to(target_device) for k, v in fresh_flow.state_dict().items()}) if self._ontology_matrix_numpy is not None: # Re-register after meta-device materialization: persistent=False buffers are not # restored by the state dict, and to_empty() leaves them as uninitialised storage. 
self.register_buffer( "ontology_distance_matrix", torch.tensor(self._ontology_matrix_numpy, device=self.px_r.device), persistent=False, ) def batch_embedding_distribution(self, batch_index_n: torch.Tensor) -> Distribution: assert self.batch_representation_mean_bd is not None assert self.batch_representation_std_unconstrained_bd is not None return Normal( self.batch_representation_mean_bd[batch_index_n.long(), :], self.batch_representation_std_unconstrained_bd[batch_index_n.long(), :].exp() + 1e-5, )
def batch_representation_from_batch_index(
    self,
    batch_index_n: torch.Tensor,
    use_mean_though_sampling: bool = False,
) -> torch.Tensor:
    """Compute a batch representation from batch indices.

    - If ``self.batch_embedded`` is False, the representation is a one-hot encoding over
      ``self.n_batch`` batches (like scvi-tools).
    - If ``self.batch_embedded`` is True and ``self.batch_representation_sampled`` is True, the
      representation comes from ``self.batch_embedding_distribution``: a reparameterized sample,
      or the distribution mean when ``use_mean_though_sampling`` is True.
    - If ``self.batch_embedded`` is True and ``self.batch_representation_sampled`` is False, the
      representation is the point estimate ``self.batch_representation_mean_bd[batch_index]``.

    Args:
        batch_index_n: Integer batch index for each cell.
        use_mean_though_sampling: (sic, "through") When the representation is sampled, return the
            distribution mean instead of a sample. Ignored otherwise.

    Returns:
        The batch representation, one row per cell (assuming ``batch_index_n`` is 1-D).
    """
    if not self.batch_embedded:
        # Flatten instead of squeeze: squeeze() turns a single-cell batch into a 0-d tensor, for
        # which one_hot would return a 1-D vector rather than a (1, n_batch) matrix.
        return torch.nn.functional.one_hot(batch_index_n.reshape(-1).long(), num_classes=self.n_batch).float()
    if not self.batch_representation_sampled:
        assert self.batch_representation_mean_bd is not None
        return self.batch_representation_mean_bd[batch_index_n.long(), :]
    embedding_distribution = self.batch_embedding_distribution(batch_index_n=batch_index_n)
    if use_mean_though_sampling:
        return embedding_distribution.mean
    return embedding_distribution.rsample()
[docs] def categorical_onehot_from_categorical_index( self, categorical_covariate_index_nd: torch.Tensor | None, ) -> torch.Tensor | None: """Compute one-hot encoding of categorical covariates from integer category indices. Args: categorical_covariate_index_nd: a tensor of shape (n, n_categorical_covariates) """ if categorical_covariate_index_nd is not None: # make the categorical covariates one-hot categorical_covariate_np = torch.cat( [ torch.nn.functional.one_hot(categorical_covariate_index_nd[:, i].long(), num_classes=n_cats).float() for i, n_cats in enumerate(self.n_cats_per_cov) ], dim=1, ) return categorical_covariate_np return None
[docs] def inference( self, x_ng: torch.Tensor, batch_nb: torch.Tensor, continuous_covariates_nc: torch.Tensor | None = None, categorical_covariate_np: torch.Tensor | None = None, ): """ High level inference method. Runs the inference (encoder) model. """ encoder_input_ng = x_ng if self.log_variational: encoder_input_ng = torch.log1p(encoder_input_ng) qz = self.z_encoder(x_ng=encoder_input_ng, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np) z = qz.rsample() if self.use_flow: assert self.flow is not None # for mypy # Build the flow prior distribution for this forward pass. # self.flow() returns a NormalizingFlow with event_shape=(n_latent,) and no context. pz = self.flow() return dict(z=z, qz=qz, pz=pz) return dict(z=z, qz=qz)
[docs] def generative( self, z_nk: torch.Tensor, library_size_n1: torch.Tensor, batch_nb: torch.Tensor, continuous_covariates_nc: torch.Tensor | None = None, categorical_covariate_np: torch.Tensor | None = None, ) -> dict[str, Distribution]: """Runs the generative model.""" inverse_overdispersion: torch.Tensor | None match self.dispersion: case "gene": inverse_overdispersion = self.px_r.exp() case "gene-cell": inverse_overdispersion = None case "gene-batch": inverse_overdispersion = torch.nn.functional.linear( batch_nb, self.px_r, ).exp() case "gene-label": inverse_overdispersion = None raise NotImplementedError count_distribution = self.decoder( z_nk=z_nk, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np, inverse_overdispersion=inverse_overdispersion, library_size_n1=library_size_n1, ) # prior on latent z pz = Normal(torch.zeros_like(z_nk), torch.ones_like(z_nk)) return dict(px=count_distribution, pz=pz)
def forward(
    self,
    x_ng: torch.Tensor,
    var_names_g: np.ndarray,
    batch_index_n: torch.Tensor,
    continuous_covariates_nc: torch.Tensor | None = None,
    categorical_covariate_index_nd: torch.Tensor | None = None,
    total_mrna_umis_n: torch.Tensor | None = None,
):
    """
    Compute the KL-annealed negative ELBO for a batch of cells.

    Args:
        x_ng: Gene counts matrix.
        var_names_g: The list of the variable names in the input data.
        batch_index_n: Batch indices of input cells as integers.
        continuous_covariates_nc: Continuous covariates for each cell (c-dimensional).
        categorical_covariate_index_nd: Categorical covariates for each cell (d-dimensional).
            Integer membership categorical codes.
        total_mrna_umis_n: Total mRNA UMIs for each cell (not log scaled) if this should be used.

    Returns:
        A dictionary with keys:
            - "loss": Mean over cells of the reconstruction loss plus the annealed, weighted KL terms.
            - "reconstruction_loss": Per-cell negative log-likelihood of ``x_ng``.
            - "kl_divergence_z": Per-cell KL divergence for the latent variable z.
            - "kl_divergence_batch": Per-cell KL divergence for the sampled batch representation,
              or the integer 0 when that term is not used.
            - "z_nk": The sampled latent variable z.
    """
    assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
    assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

    batch_nb = self.batch_representation_from_batch_index(batch_index_n)
    categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd)

    if self.use_size_factor_key:
        assert total_mrna_umis_n is not None, "total_mrna_umis_n must be provided when use_size_factor_key=True"
        library_size_n1 = torch.log(total_mrna_umis_n).unsqueeze(-1)
    else:
        library_size_n1 = torch.log(x_ng.sum(dim=-1, keepdim=True))

    if self.training and self.input_gene_dropout_rate > 0.0:
        # Randomly drop out some genes in the encoder input, during training only. In eval mode
        # (e.g. validate(), which also calls forward) the encoder sees the full input so that
        # validation metrics are not perturbed by random masking.
        with torch.no_grad():
            dropout_mask_ng = torch.rand_like(x_ng) > self.input_gene_dropout_rate
            inference_input_x_ng = x_ng * dropout_mask_ng
    else:
        inference_input_x_ng = x_ng

    inference_outputs = self.inference(
        x_ng=inference_input_x_ng,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )
    generative_outputs = self.generative(
        z_nk=inference_outputs["z"],
        library_size_n1=library_size_n1,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )

    kl_annealing_weight = compute_annealed_kl_weight(
        epoch=self.epoch,
        step=self.step,
        n_epochs_kl_warmup=self.kl_warmup_epochs,
        n_steps_kl_warmup=self.kl_warmup_steps,
        max_kl_weight=1.0,
        min_kl_weight=self.kl_annealing_start,
    )

    # KL divergence for z
    if self.use_flow:
        # MC estimate: KL(q(z|x) || p_flow(z)) = E_q[log q(z) - log p_flow(z)]
        # qz is Normal with per-dim log_probs; sum over latent dim to get per-cell scalar.
        # self.flow() (NormalizingFlow) already reduces the event dimension, giving shape (n,).
        z_nk = inference_outputs["z"]
        log_qz_n = inference_outputs["qz"].log_prob(z_nk).sum(dim=-1)
        # nan_to_num guards against spline instability producing NaN / ±inf in log_pz.
        log_pz_n = torch.nan_to_num(inference_outputs["pz"].log_prob(z_nk), nan=0.0, posinf=0.0, neginf=-1e4)
        kl_divergence_z_n = log_qz_n - log_pz_n
    else:
        kl_divergence_z_n = kl(inference_outputs["qz"], generative_outputs["pz"]).sum(dim=1)

    # optional KL divergence for batch representation
    kl_divergence_batch_n: torch.Tensor | int
    if self.batch_representation_sampled and (self.batch_kl_weight_max > 0):
        kl_divergence_batch_n = kl(
            self.batch_embedding_distribution(batch_index_n=batch_index_n),
            Normal(torch.zeros_like(batch_nb), torch.ones_like(batch_nb)),
        ).sum(dim=1)
    else:
        kl_divergence_batch_n = 0

    # reconstruction loss
    rec_loss_n = -generative_outputs["px"].log_prob(x_ng).sum(-1)

    # full loss
    assert kl_annealing_weight >= 0.0 and kl_annealing_weight <= 1.0, (
        f"Invalid KL annealing weight: {kl_annealing_weight}"
    )
    loss = torch.mean(
        rec_loss_n
        + kl_annealing_weight
        * (self.z_kl_weight_max * kl_divergence_z_n + self.batch_kl_weight_max * kl_divergence_batch_n),
        dim=0,
    )

    return {
        "loss": loss,
        "reconstruction_loss": rec_loss_n,
        "kl_divergence_z": kl_divergence_z_n,
        "kl_divergence_batch": kl_divergence_batch_n,
        "z_nk": inference_outputs["z"],
    }
def _latent_value_from_latent_distribution(self, d: Distribution) -> torch.Tensor:
    """Reduce a latent distribution to a deterministic point estimate.

    Args:
        d: A latent distribution, e.g. the ``qz`` returned by ``inference``.

    Returns:
        The mean of ``d`` (no sampling is performed).
    """
    point_estimate = d.mean
    return point_estimate
[docs] def predict( self, x_ng: torch.Tensor, var_names_g: np.ndarray, batch_index_n: torch.Tensor, continuous_covariates_nc: torch.Tensor | None = None, categorical_covariate_index_nd: torch.Tensor | None = None, ): """ Args: x_ng: Gene counts matrix. var_names_g: The list of the variable names in the input data. batch_index_n: Batch indices of input cells as integers. continuous_covariates_nc: Continuous covariates for each cell (c-dimensional). categorical_covariate_index_nd: Categorical covariates for each cell (d-dimensional where d is number of categorical variables). Values are integer membership categorical codes. Returns: A dictionary with the following keys: - ``x_ng``: - If :attr:`self.reconstruct_counts_on_predict` is False: - (x_ng is a notational misnomer) Embedding of the input data into the scVI latent space, typically referred to as ``z_nk``. - If :attr:`self.reconstruct_counts_on_predict` is True: - (x_ng is a notational misnomer) Reconstruction of the input data: ``x_ng'``, which may have a different number of genes depending on :attr:`self.reconstruction_var_names_g`. 
""" if self.reconstruct_counts_on_predict: assert self.reconstruction_var_names_g is not None try: gene_inds = [np.where(var_names_g == gid)[0][0] for gid in self.reconstruction_var_names_g] except IndexError: raise ValueError( f"Some genes to reconstruct ({len(set(self.reconstruction_var_names_g) - set(var_names_g))}) " f"are missing from the input data: {set(self.reconstruction_var_names_g) - set(var_names_g)}" ) return self.reconstruct( x_ng=x_ng, var_names_g=var_names_g, gene_inds=gene_inds, batch_index_n=batch_index_n, continuous_covariates_nc=continuous_covariates_nc, categorical_covariate_index_nd=categorical_covariate_index_nd, transform_batch=self.reconstruction_transform_batch, transform_categorical_covariates=self.reconstruction_transform_categorical_covariates, n_latent_samples=self.reconstruction_n_latent_samples, use_importance_sampling=self.reconstruction_use_importance_sampling, use_latent_mean=self.reconstruction_use_latent_mean, reconstructed_library_size=self.reconstructed_library_size, ) assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g) assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g) batch_nb = self.batch_representation_from_batch_index(batch_index_n) categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd) z_nk = self._latent_value_from_latent_distribution( self.inference( x_ng=x_ng, batch_nb=batch_nb, continuous_covariates_nc=continuous_covariates_nc, categorical_covariate_np=categorical_covariate_np, )["qz"] ) return {"x_ng": z_nk}
@torch.no_grad()
def reconstruct(
    self,
    x_ng: torch.Tensor,
    var_names_g: np.ndarray,
    gene_inds: np.ndarray | list[int],
    batch_index_n: torch.Tensor,
    continuous_covariates_nc: torch.Tensor | None = None,
    categorical_covariate_index_nd: torch.Tensor | None = None,
    transform_batch: str | int | None = None,
    transform_categorical_covariates: list[int] | None = None,
    use_latent_mean: bool = False,
    n_latent_samples: int = 1000,
    use_importance_sampling: bool = False,
    reconstructed_library_size: float = 10_000,
):
    """
    Reconstruct the data using the VAE, optionally transforming the batch.

    Note: scvi-tools uses the following strategy -
        - for each transform_batch, put the data through the encoder (no dropout)
        - sample n_latent_samples times to get several z values
        - take the mean of the generative distribution
        - obtain tensor shape [n_transform_batches, n_latent_samples, n_cells, n_genes]
        - take a mean over the batches dimension
        (- they optionally use importance sampling based on sampled z likelihoods)
        - take a (weighted?) mean over the n_latent_samples dimension

    Args:
        x_ng: Gene counts matrix.
        var_names_g: The list of the variable names in the input data.
        gene_inds: The indices of the genes from var_names_g to be reconstructed. Output order preserves
            this order.
        batch_index_n: Batch indices of input cells as integers.
        continuous_covariates_nc: Continuous covariates for each cell (c-dimensional).
        categorical_covariate_index_nd: Categorical covariates for each cell (d-dimensional where d is the
            number of categorical variables). Used for the encoder; also used for the decoder when
            transform_categorical_covariates is None.
        transform_batch: If None, decode each cell under its own batch (a plain reconstruction). If an
            integer in ``[0, self.n_batch)``, decode every cell under that batch. If the string ``"mean"``,
            decode under every batch in ``range(self.n_batch)`` and average the results.
        transform_categorical_covariates: A list of integer category indices, one per categorical covariate
            (in the same order as n_cats_per_cov), to fix for all cells during decoding. Must be supplied
            when transform_batch is not None and the decoder uses categorical covariates — use
            enumerate_observed_batch_covariate_combinations() to identify valid combinations present in
            the training data. If None, the per-cell observed categorical covariates are passed to the
            decoder.
        use_latent_mean: If True, use the mean of the latent distribution instead of sampling.
        n_latent_samples: The number of latent samples to use for reconstruction (must be >= 1 when
            use_latent_mean is False).
        use_importance_sampling: True to weight each latent sample by exp(log q(z_s) - log q(mean z))
            under the posterior q. Weights are normalized per cell.
        reconstructed_library_size: The library size to use for the reconstruction, common to all cells.

    Returns:
        A dictionary with the following keys:

        - ``x_ng``: Model's reconstruction of the input data, possibly de-batched. The notational misnomer
          here is that "g" no longer stands for all genes, but the genes in `gene_inds`.

    Raises:
        ValueError: For an invalid ``transform_batch``; when ``transform_categorical_covariates`` is missing
            but required; or when ``n_latent_samples < 1`` while sampling.
    """
    assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
    assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

    if (not use_latent_mean) and n_latent_samples < 1:
        raise ValueError(f"n_latent_samples must be at least 1 when sampling, got {n_latent_samples}")

    if transform_batch is None:
        # a list of size one with the measured values: an actual reconstruction
        transformed_batch_index_n_list = [batch_index_n]
    elif isinstance(transform_batch, str):
        if transform_batch != "mean":
            raise ValueError(
                'If transform_batch is a string, it must be "mean" which '
                "will project counts into each batch and compute the mean, "
                "otherwise specify a particular batch using its integer index"
            )
        # project into every batch, as documented (not only the first few)
        transformed_batch_index_n_list = [torch.ones_like(batch_index_n) * i for i in range(self.n_batch)]
    else:
        if transform_batch >= self.n_batch:
            raise ValueError(f"transform_batch must be less than self.n_batch: {self.n_batch}")
        if transform_batch < 0:
            raise ValueError(f"transform_batch must be non-negative, got {transform_batch}")
        transformed_batch_index_n_list = [torch.ones_like(batch_index_n) * transform_batch]

    # enforce that transform_categorical_covariates is supplied when the decoder uses categorical covariates
    # and transform_batch is set — otherwise the decoder sees per-cell observed covariates, not the target
    # condition, and the reconstruction will still reflect per-cell batch identity.
    if (
        (transform_batch is not None)
        and self._decoder_uses_categorical_covariates
        and (transform_categorical_covariates is None)
    ):
        raise ValueError(
            "This model's decoder uses categorical covariates, but transform_categorical_covariates was not "
            "supplied. When transform_batch is not None, transform_categorical_covariates must also be provided "
            "so that all cells are decoded under the same (real-world) batch+categoricals condition. "
            "Use enumerate_observed_batch_covariate_combinations() to identify valid (batch, covariate) "
            "combinations that were present in the training data, then supply the desired covariate indices "
            "as reconstruction_transform_categorical_covariates in the model config."
        )

    batch_nb = self.batch_representation_from_batch_index(batch_index_n)
    categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd)
    library_size_n1 = torch.log(x_ng.sum(dim=-1, keepdim=True))

    # generate the latent distribution
    inference_outputs = self.inference(
        x_ng=x_ng,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )
    qz = inference_outputs["qz"]
    # the posterior mean (and its log density) do not change across samples: compute once
    mean_z_nk = qz.mean
    log_qz_at_mean_nk = qz.log_prob(mean_z_nk) if use_importance_sampling else None

    def _run_generative_and_scale_output(
        z_nk: torch.Tensor,
        local_batch_nb: torch.Tensor,
        local_categorical_covariate_np: torch.Tensor | None,
    ) -> torch.Tensor:
        # use that latent sample and the transform batch to generate data
        generative_outputs = self.generative(
            z_nk=z_nk,
            library_size_n1=library_size_n1,
            batch_nb=local_batch_nb,
            continuous_covariates_nc=continuous_covariates_nc,
            categorical_covariate_np=local_categorical_covariate_np,
        )
        # take the mean of the distribution
        counts_ng = generative_outputs["px"].mean
        # normalize and re-scale
        normalized_counts_ng = counts_ng / counts_ng.sum(dim=-1, keepdim=True)
        scaled_counts_ng = reconstructed_library_size * normalized_counts_ng
        # subset to genes of interest
        return scaled_counts_ng[:, gene_inds]

    x_tilde_np: torch.Tensor = torch.zeros(x_ng.shape[0], len(gene_inds), device=x_ng.device)

    # go through each output batch projection (just one unless transform_batch == "mean")
    for transformed_batch_index_n in transformed_batch_index_n_list:
        local_batch_nb = self.batch_representation_from_batch_index(
            transformed_batch_index_n,
            use_mean_though_sampling=True,  # don't sample during reconstruction
        )

        # build the categorical covariate tensor for this reconstruction condition
        if transform_categorical_covariates is not None:
            fixed_idx_nd = (
                torch.tensor(transform_categorical_covariates, dtype=torch.long, device=x_ng.device)
                .unsqueeze(0)
                .expand(x_ng.shape[0], -1)
            )
            local_categorical_covariate_np = self.categorical_onehot_from_categorical_index(fixed_idx_nd)
        else:
            local_categorical_covariate_np = categorical_covariate_np

        if use_latent_mean:
            # use the mean in the latent space
            x_tilde_np = x_tilde_np + _run_generative_and_scale_output(
                z_nk=mean_z_nk,
                local_batch_nb=local_batch_nb,
                local_categorical_covariate_np=local_categorical_covariate_np,
            )
            continue

        weighted_counts_sum_np: torch.Tensor | int = 0
        # Per-cell sum of weights: each cell's weighted average is normalized by its own weights.
        weight_sum_n1: torch.Tensor | float = 0.0
        for _ in range(n_latent_samples):
            # take a sample from the latent space
            sampled_z_nk = qz.sample()

            # compute weight for importance sampling
            importance_weight_n1: torch.Tensor | float
            if use_importance_sampling:
                importance_weight_n1 = (qz.log_prob(sampled_z_nk) - log_qz_at_mean_nk).sum(dim=-1).exp().unsqueeze(-1)
            else:
                importance_weight_n1 = 1.0

            # run generative model and scale output
            scaled_counts_np = _run_generative_and_scale_output(
                z_nk=sampled_z_nk,
                local_batch_nb=local_batch_nb,
                local_categorical_covariate_np=local_categorical_covariate_np,
            )
            weighted_counts_sum_np = weighted_counts_sum_np + scaled_counts_np * importance_weight_n1
            weight_sum_n1 = weight_sum_n1 + importance_weight_n1

        x_tilde_np = x_tilde_np + weighted_counts_sum_np / weight_sum_n1

    x_tilde_np = x_tilde_np / len(transformed_batch_index_n_list)

    return {"x_ng": x_tilde_np}
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------
def on_validation_epoch_start(self, trainer: pl.Trainer) -> None:
    """Allocate zeroed per-epoch validation accumulators on the device of the model parameters.

    Accumulators: ELBO / reconstruction sums and cell count; per-class latent sums and counts (when
    ``self.num_classes`` is set); per-batch latent statistics for the batch silhouette (when
    ``self.n_batch > 1``); and train/test reservoirs of ``self.val_cell_type_classifier_reservoir_size``
    latent vectors and labels for the cell type classifier (when ``self.num_classes`` is set).
    """
    device = next(self.parameters()).device

    # ELBO / reconstruction accumulators
    self._val_elbo_sum = torch.zeros(1, device=device)
    self._val_rec_sum = torch.zeros(1, device=device)
    self._val_n_cells = torch.zeros(1, device=device)

    # Ontology accumulators (only when num_classes is configured)
    if self.num_classes is not None:
        self._val_z_sum_kd = torch.zeros(self.num_classes, self.n_latent, device=device)
        self._val_class_count_k = torch.zeros(self.num_classes, device=device)

    # Batch silhouette accumulators
    if self.n_batch > 1:
        self._val_batch_z_sum_bk = torch.zeros(self.n_batch, self.n_latent, device=device)
        self._val_batch_z_sq_sum_b = torch.zeros(self.n_batch, device=device)
        self._val_batch_count_b = torch.zeros(self.n_batch, device=device)

    # Cell type classifier reservoirs (train = even batch_idx, test = odd batch_idx)
    if self.num_classes is not None:
        rs = self.val_cell_type_classifier_reservoir_size
        self._val_cl_train_z = torch.zeros(rs, self.n_latent, device=device)
        self._val_cl_train_y = torch.zeros(rs, dtype=torch.long, device=device)
        self._val_cl_train_fill = 0
        self._val_cl_train_seen = 0
        self._val_cl_test_z = torch.zeros(rs, self.n_latent, device=device)
        self._val_cl_test_y = torch.zeros(rs, dtype=torch.long, device=device)
        self._val_cl_test_fill = 0
        self._val_cl_test_seen = 0

def validate(
    self,
    trainer: pl.Trainer,
    pl_module: pl.LightningModule,
    batch_idx: int,
    x_ng: torch.Tensor,
    var_names_g: np.ndarray,
    batch_index_n: torch.Tensor,
    continuous_covariates_nc: torch.Tensor | None = None,
    categorical_covariate_index_nd: torch.Tensor | None = None,
    total_mrna_umis_n: torch.Tensor | None = None,
    validation_cell_type_index_n: torch.Tensor | None = None,
) -> None:
    """Accumulate validation statistics for one batch.

    Runs ``forward``, logs the annealed loss as ``val_loss``, and adds this batch's contribution to the
    accumulators allocated in ``on_validation_epoch_start``. Latents from even ``batch_idx`` go to the
    classifier's train reservoir and those from odd ``batch_idx`` to its test reservoir.
    """
    n = x_ng.shape[0]
    output = self(
        x_ng=x_ng,
        var_names_g=var_names_g,
        batch_index_n=batch_index_n,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_index_nd=categorical_covariate_index_nd,
        total_mrna_umis_n=total_mrna_umis_n,
    )

    # log annealed loss for progress bar / on-step visibility
    if isinstance(output["loss"], torch.Tensor):
        pl_module.log("val_loss", output["loss"], sync_dist=True, on_epoch=True, batch_size=n)

    # accumulate exact ELBO (no annealing)
    kl_batch = output["kl_divergence_batch"]
    if not isinstance(kl_batch, torch.Tensor):
        kl_batch = torch.zeros(n, device=x_ng.device)
    assert isinstance(output["reconstruction_loss"], torch.Tensor)
    assert isinstance(output["kl_divergence_z"], torch.Tensor)
    assert isinstance(output["z_nk"], torch.Tensor)
    elbo_n = -(output["reconstruction_loss"] + output["kl_divergence_z"] + kl_batch)
    self._val_elbo_sum += elbo_n.sum().detach()
    self._val_rec_sum += output["reconstruction_loss"].sum().detach()
    self._val_n_cells += n

    z_nk = output["z_nk"].detach()

    # accumulate per-class latent sums for ontology metric
    if validation_cell_type_index_n is not None and self.num_classes is not None:
        idx = validation_cell_type_index_n.long()
        self._val_z_sum_kd.index_add_(0, idx, z_nk)
        self._val_class_count_k.index_add_(0, idx, torch.ones(n, device=x_ng.device))

    # accumulate per-batch latent sums for silhouette metric
    if self.n_batch > 1:
        bidx = batch_index_n.long()
        self._val_batch_z_sum_bk.index_add_(0, bidx, z_nk)
        self._val_batch_z_sq_sum_b.index_add_(0, bidx, (z_nk**2).sum(dim=-1))
        self._val_batch_count_b.index_add_(0, bidx, torch.ones(n, device=x_ng.device))

    # reservoir sampling for cell type classifier
    if validation_cell_type_index_n is not None and self.num_classes is not None:
        if batch_idx % 2 == 0:
            buf_z, buf_y = self._val_cl_train_z, self._val_cl_train_y
            fill, seen = self._val_cl_train_fill, self._val_cl_train_seen
        else:
            buf_z, buf_y = self._val_cl_test_z, self._val_cl_test_y
            fill, seen = self._val_cl_test_fill, self._val_cl_test_seen

        rs = self.val_cell_type_classifier_reservoir_size
        labels = validation_cell_type_index_n.long()

        # phase 1: direct fill up to capacity
        space = rs - fill
        direct = min(space, n)
        if direct > 0:
            buf_z[fill : fill + direct] = z_nk[:direct]
            buf_y[fill : fill + direct] = labels[:direct]
            fill += direct

        # phase 2: reservoir replacement (Algorithm R) for overflow cells. Cell i of this batch is
        # item number seen + i (0-based) of the stream, so j must be uniform over [0, seen + i].
        for i in range(direct, n):
            j = int(torch.randint(0, seen + i + 1, (1,)).item())
            if j < rs:
                buf_z[j] = z_nk[i]
                buf_y[j] = labels[i]

        seen += n
        if batch_idx % 2 == 0:
            self._val_cl_train_fill, self._val_cl_train_seen = fill, seen
        else:
            self._val_cl_test_fill, self._val_cl_test_seen = fill, seen

@staticmethod
def _weighted_spearman(
    latent_dists: torch.Tensor,
    onto_dists: torch.Tensor,
    counts: torch.Tensor,
) -> float:
    """Weighted Spearman correlation between the upper triangles of two square distance matrices.

    Pair (i, j) is weighted by ``counts[i] * counts[j]``; pairs whose ontology distance is not finite
    are dropped. Returns NaN for fewer than 2 classes or fewer than 3 finite pairs, and 0.0 (with a
    warning) when the rank-covariance denominator is zero or NaN.
    """
    import scipy.stats

    m = latent_dists.shape[0]
    if m < 2:
        return float("nan")
    idx = torch.triu_indices(m, m, offset=1)
    x = latent_dists[idx[0], idx[1]].cpu().float().numpy()
    y = onto_dists[idx[0], idx[1]].cpu().float().numpy()
    w = (counts[idx[0]] * counts[idx[1]]).cpu().float().numpy()
    # exclude disconnected pairs (inf ontology distance — different DAG subtrees)
    finite_mask = np.isfinite(y)
    if finite_mask.sum() < 3:
        return float("nan")
    x, y, w = x[finite_mask], y[finite_mask], w[finite_mask]
    rx = scipy.stats.rankdata(x)
    ry = scipy.stats.rankdata(y)
    cov = np.cov(rx, ry, aweights=w)
    denom = float(np.sqrt(max(0.0, cov[0, 0] * cov[1, 1])))
    if (denom == 0.0) or np.isnan(denom):
        logger.warning(
            "val_ontology_spearman: zero variance in ontology distances (all pairs at equal distance); logging 0.0"
        )
        return 0.0
    return float(cov[0, 1] / denom)

def on_validation_epoch_end(self, lightning_module: pl.LightningModule, trainer: pl.Trainer) -> None:
    """Reduce the validation accumulators and log epoch-level metrics.

    Nothing is done during the sanity check. Under DDP the accumulators are all-reduced and the
    classifier reservoirs gathered first. ``val_reconstruction_loss`` is then logged from every rank
    (its value is already global). The remaining metrics (ELBO, ontology Spearman, batch silhouette,
    classifier accuracy and error distance) are computed by rank 0 only, and only when a logger exists.
    """
    if trainer.sanity_checking:
        return

    # sync all accumulators across DDP ranks
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        if world_size > 1:
            tensors_to_sync = [self._val_elbo_sum, self._val_rec_sum, self._val_n_cells]
            if self.num_classes is not None:
                tensors_to_sync += [self._val_z_sum_kd, self._val_class_count_k]
            if self.n_batch > 1:
                tensors_to_sync += [
                    self._val_batch_z_sum_bk,
                    self._val_batch_z_sq_sum_b,
                    self._val_batch_count_b,
                ]
            for t in tensors_to_sync:
                torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM)

            # gather classifier reservoirs: exchange fill counts then concatenate valid slices
            if self.num_classes is not None:
                for buf_z, buf_y, fill_attr in [
                    (self._val_cl_train_z, self._val_cl_train_y, "_val_cl_train_fill"),
                    (self._val_cl_test_z, self._val_cl_test_y, "_val_cl_test_fill"),
                ]:
                    fill = getattr(self, fill_attr)
                    fill_t = torch.tensor([fill], device=buf_z.device)
                    fills = [torch.zeros_like(fill_t) for _ in range(world_size)]
                    torch.distributed.all_gather(fills, fill_t)
                    gathered_z = [torch.zeros_like(buf_z) for _ in range(world_size)]
                    gathered_y = [torch.zeros_like(buf_y) for _ in range(world_size)]
                    torch.distributed.all_gather(gathered_z, buf_z)
                    torch.distributed.all_gather(gathered_y, buf_y)
                    combined_z = torch.cat([g[: int(f.item())] for g, f in zip(gathered_z, fills)], dim=0)
                    combined_y = torch.cat([g[: int(f.item())] for g, f in zip(gathered_y, fills)], dim=0)
                    # apply reservoir sub-sampling if combined exceeds capacity
                    rs = self.val_cell_type_classifier_reservoir_size
                    if combined_z.shape[0] > rs:
                        perm = torch.randperm(combined_z.shape[0], device=combined_z.device)[:rs]
                        combined_z = combined_z[perm]
                        combined_y = combined_y[perm]
                    buf_z[: combined_z.shape[0]] = combined_z
                    buf_y[: combined_y.shape[0]] = combined_y
                    setattr(self, fill_attr, combined_z.shape[0])

    # After the optional all_reduce above these sums are global and identical on every rank.
    n_cells = self._val_n_cells.item()
    if n_cells > 0:
        # Log from every rank and without sync_dist: the value is already reduced, and a
        # sync_dist=True call issued by rank 0 alone would start a collective the other ranks
        # never join (a DDP hang).
        val_rec = self._val_rec_sum / self._val_n_cells
        lightning_module.log(
            name="val_reconstruction_loss",
            value=val_rec.detach(),
            sync_dist=False,
        )

    # only rank 0 computes and logs the remaining epoch-level metrics
    if trainer.global_rank != 0:
        return
    if trainer.logger is None:
        return

    step = trainer.global_step

    # ELBO
    if n_cells > 0:
        val_elbo = (self._val_elbo_sum / self._val_n_cells).item()
        trainer.logger.log_metrics({"val_elbo": val_elbo}, step=step)

    # ontology-weighted Spearman
    if self.num_classes is not None and self.ontology_distance_matrix is not None:
        valid_mask = self._val_class_count_k > 0
        n_valid = int(valid_mask.sum().item())
        if n_valid >= 2:
            counts_valid = self._val_class_count_k[valid_mask]
            centroids = self._val_z_sum_kd[valid_mask] / counts_valid.unsqueeze(1)
            latent_dists = torch.cdist(centroids, centroids)
            onto_dists = self.ontology_distance_matrix[valid_mask][:, valid_mask]
            spearman = self._weighted_spearman(latent_dists, onto_dists, counts_valid)
            if not np.isnan(spearman):
                trainer.logger.log_metrics({"val_ontology_spearman": spearman}, step=step)

    # centroid-based batch silhouette
    if self.n_batch > 1:
        valid_mask = self._val_batch_count_b > 0
        n_valid = int(valid_mask.sum().item())
        if n_valid >= 2:
            counts_b = self._val_batch_count_b[valid_mask]
            centroids_bk = self._val_batch_z_sum_bk[valid_mask] / counts_b.unsqueeze(1)
            mean_sq_b = self._val_batch_z_sq_sum_b[valid_mask] / counts_b
            # within-batch spread: sqrt(E[||z||^2] - ||centroid||^2)
            intra_b = (mean_sq_b - centroids_bk.pow(2).sum(dim=-1)).clamp(min=0.0).sqrt()
            D = torch.cdist(centroids_bk, centroids_bk)
            # mask self-distances
            D.fill_diagonal_(float("inf"))
            b_b = D.min(dim=1).values  # nearest other centroid distance
            s_b = (b_b - intra_b) / torch.max(b_b, intra_b).clamp(min=1e-12)
            val_silhouette = (s_b * counts_b).sum() / counts_b.sum()
            trainer.logger.log_metrics({"val_batch_silhouette": val_silhouette.item()}, step=step)

    # cell type logistic regression classifier
    if self.num_classes is not None and self._val_cl_train_fill >= 2 and self._val_cl_test_fill >= 1:
        # Lightning may run this hook under torch.inference_mode() (its default for evaluation), where
        # enable_grad() alone does not re-enable autograd and inference tensors cannot be saved for
        # backward. Leave inference mode and clone the buffers into ordinary tensors before fitting.
        with torch.inference_mode(False), torch.enable_grad():
            X_train = self._val_cl_train_z[: self._val_cl_train_fill].clone().float()
            y_train = self._val_cl_train_y[: self._val_cl_train_fill].clone()
            X_test = self._val_cl_test_z[: self._val_cl_test_fill].clone().float()
            y_test = self._val_cl_test_y[: self._val_cl_test_fill].clone()

            # normalize features for stable LBFGS convergence
            mu = X_train.mean(0)
            sigma = X_train.std(0).clamp(min=1e-8)
            X_train = (X_train - mu) / sigma
            X_test = (X_test - mu) / sigma

            W = torch.zeros(self.num_classes, self.n_latent, device=X_train.device, requires_grad=True)
            b = torch.zeros(self.num_classes, device=X_train.device, requires_grad=True)
            opt = torch.optim.LBFGS([W, b], max_iter=200, line_search_fn="strong_wolfe")

            def closure():
                opt.zero_grad()
                loss = torch.nn.functional.cross_entropy(X_train @ W.T + b, y_train)
                loss.backward()
                return loss

            opt.step(closure)

        with torch.no_grad():
            logits_test = X_test @ W.T + b
            preds = logits_test.argmax(dim=-1)
            top1 = (preds == y_test).float().mean().detach()
            # Only rank 0 reaches this point, so the call must not synchronize across ranks.
            lightning_module.log(
                name="val_cell_type_top1_accuracy",
                value=top1,
                rank_zero_only=True,
            )

            if self.ontology_distance_matrix is not None:
                wrong_mask = preds != y_test
                if wrong_mask.any():
                    err_dist = self.ontology_distance_matrix[preds[wrong_mask], y_test[wrong_mask]]
                    finite = err_dist.isfinite()
                    if finite.any():
                        mean_err = err_dist[finite].mean().item()
                        trainer.logger.log_metrics({"val_cell_type_mean_error_distance": mean_err}, step=step)

def on_train_batch_end(self, trainer: pl.Trainer) -> None:
    """Record the trainer's global step and current epoch and log the current KL annealing weight.

    ``self.step`` and ``self.epoch`` are the values ``forward`` passes to
    ``compute_annealed_kl_weight``.
    """
    self.step = trainer.global_step
    self.epoch = trainer.current_epoch

    # log these values to pytorch lightning logger
    if trainer.logger is not None:
        trainer.logger.log_metrics(
            {
                "kl_annealing_weight": compute_annealed_kl_weight(
                    epoch=self.epoch,
                    step=self.step,
                    n_epochs_kl_warmup=self.kl_warmup_epochs,
                    n_steps_kl_warmup=self.kl_warmup_steps,
                    max_kl_weight=1.0,
                    min_kl_weight=self.kl_annealing_start,
                )
            },
            step=trainer.global_step,
        )
def _as_key_list(keys: str | Sequence[str]) -> list[str]:
    """Normalize a single column name or a sequence of column names to a new list of column names."""
    return [keys] if isinstance(keys, str) else list(keys)


def batch_index_to_batch_label(adata: AnnData, batch_keys: str | Sequence[str]) -> pd.DataFrame:
    """
    Convert integer batch index used in the model to a human-readable batch label.

    Args:
        adata: AnnData object. Can be any individual shard as long as categoricals contain all categories.
        batch_keys: Batch key(s). A single string is treated as a one-element list.

    Returns:
        DataFrame with one row per combination of categories across ``batch_keys`` (every combination of
        the declared categories, not only those observed in ``adata``), one column per batch covariate,
        and an extra column "scvi_batch_code" with the code used in the model.
    """
    batch_keys = _as_key_list(batch_keys)
    if len(batch_keys) > 1:
        # only the multi-key product-code lookup is flagged as experimental
        logger.warning("The batch_index_to_batch_label lookup for multiple batch_keys is still experimental.")
    df = _enumerate_categorical_combinations(adata.obs[batch_keys])
    df["scvi_batch_code"] = categories_to_product_codes(df)
    return df


def enumerate_observed_batch_covariate_combinations(
    obs: pd.DataFrame,
    batch_keys: str | Sequence[str],
    categorical_covariate_keys: str | Sequence[str] | None = None,
) -> pd.DataFrame:
    """
    Tabulate every observed combination of (batch, categorical covariates) in an AnnData obs DataFrame,
    along with the integer indices that the model uses internally for each field.

    This is intended to help users choose valid values for ``reconstruction_transform_batch`` and
    ``reconstruction_transform_categorical_covariates`` when running scVI's reconstruct method. The model
    requires that the specified (batch, covariate) combination was actually present in the training data;
    this function makes it easy to inspect which combinations exist and how many cells belong to each.

    Args:
        obs: The ``adata.obs`` DataFrame. All columns listed in ``batch_keys`` and
            ``categorical_covariate_keys`` must be pandas Categoricals.
        batch_keys: The obs column name(s) used as the batch key during training (matching the
            ``batch_index_n`` field in the datamodule config). A single string is treated as a
            one-element list. The integer ``batch_index`` reported here is the value passed to the
            model as ``batch_index_n``, computed as ``categories_to_product_codes`` over these columns.
        categorical_covariate_keys: The obs column name(s) used as categorical covariates during training
            (matching the ``categorical_covariate_index_nd`` field in the datamodule config), in the same
            order. A single string is treated as a one-element list. For each key, the reported
            ``{key}_index`` is the 0-based integer code within that covariate's categories — i.e., the value
            to put in position ``i`` of ``reconstruction_transform_categorical_covariates``. If None, no
            covariate columns are included.

    Returns:
        A DataFrame with one row per observed (batch, covariate) combination, sorted by ``cell_count``
        descending (ties keep the order in which the combinations first appear in ``obs``). Columns are:

        - One label column per batch key (human-readable category label).
        - ``batch_index``: the integer passed to the model as ``batch_index_n``.
        - One label column per categorical covariate key (human-readable category label). A key that is
          also a batch key is not repeated.
        - One ``{key}_index`` column per categorical covariate key (0-based integer code within that
          covariate's categories, for use as ``reconstruction_transform_categorical_covariates[i]``).
        - ``cell_count``: number of cells with that combination.

        Cells with a missing (NaN) label in any grouping column are not counted.

    Raises:
        ValueError: If ``batch_keys`` is empty, or a listed column is not a pandas Categorical.
        KeyError: If a listed column is not present in ``obs``.
    """
    batch_keys = _as_key_list(batch_keys)
    covariate_keys = [] if categorical_covariate_keys is None else _as_key_list(categorical_covariate_keys)
    if not batch_keys:
        raise ValueError("batch_keys must contain at least one column name.")

    work = obs[batch_keys].copy()
    for k in batch_keys:
        if not isinstance(obs[k].dtype, pd.CategoricalDtype):
            raise ValueError(f"batch_keys column '{k}' must be a pandas Categorical.")

    # compute batch_index using the same logic as the training pipeline
    batch_codes = categories_to_product_codes(work[batch_keys] if len(batch_keys) > 1 else work[batch_keys[0]])
    work["batch_index"] = pd.Series(batch_codes, index=obs.index, name="batch_index")

    # group by label columns + index columns; cell_count is the group size
    group_cols: list[str] = [*batch_keys, "batch_index"]
    for k in covariate_keys:
        if not isinstance(obs[k].dtype, pd.CategoricalDtype):
            raise ValueError(f"categorical_covariate_keys column '{k}' must be a pandas Categorical.")
        idx_col = f"{k}_index"
        work[k] = obs[k]
        work[idx_col] = obs[k].cat.codes.values
        # skip columns already grouped on (e.g. a covariate that is also a batch key): a duplicated
        # grouping key would make reset_index fail on the repeated column name
        for col in (k, idx_col):
            if col not in group_cols:
                group_cols.append(col)

    result = (
        work.groupby(group_cols, observed=True, sort=False)
        .size()
        .reset_index(name="cell_count")
        # stable sort keeps tied combinations in first-appearance order, making the output deterministic
        .sort_values("cell_count", ascending=False, kind="stable")
        .reset_index(drop=True)
    )
    return result


def _n_cats_per_column(df: pd.DataFrame) -> list[int]:
    """
    Return the number of categories for each column in a DataFrame, assuming all columns are categorical.
    """
    return [len(df[key].cat.categories) for key in df.columns]


def _enumerate_categorical_combinations(df: pd.DataFrame) -> pd.DataFrame:
    """
    Enumerate all possible combinations of categories in a DataFrame of categorical covariates.

    Combinations are generated with ``itertools.product`` over each column's category codes, so the last
    column varies fastest. Each output column is categorical with the same category list as the
    corresponding input column.
    """
    # one range of integer codes per column; product() yields every combination of codes
    category_ranges = [range(n_cats) for n_cats in _n_cats_per_column(df)]
    enumerated_df = pd.DataFrame(list(itertools.product(*category_ranges)), columns=df.columns)
    for c in df.columns:
        categories = df[c].cat.categories
        # map codes back to labels, then restore the input column's full category list
        enumerated_df[c] = (
            enumerated_df[c].map(dict(enumerate(categories))).astype("category").cat.set_categories(categories)
        )
    return enumerated_df