# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
"""Flexible modified version of single-cell variational inference (scVI) re-implemented in Cellarium ML."""
import importlib
import itertools
import logging
from abc import abstractmethod
from typing import Any, Literal, Sequence
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch
import zuko.flows
from anndata import AnnData
from torch.distributions import Distribution, Normal, Poisson
from torch.distributions import kl_divergence as kl
from cellarium.ml.distributions import NegativeBinomial
from cellarium.ml.layers import DressedLayer, FullyConnectedLinear
from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.utilities.data import categories_to_product_codes
from cellarium.ml.utilities.testing import (
assert_arrays_equal,
assert_columns_and_array_lengths_equal,
)
logger = logging.getLogger(__name__)
def class_from_class_path(class_path: str):
    """Resolve a fully qualified dotted path (e.g. ``"torch.nn.Linear"``) to the object it names.

    The part before the last dot is imported as a module and the part after it is looked up
    as an attribute of that module.

    Args:
        class_path: dotted path ``"package.module.Name"``.

    Returns:
        The attribute ``Name`` of module ``package.module``.
    """
    module_path, attribute_name = class_path.rsplit(".", maxsplit=1)
    return getattr(importlib.import_module(module_path), attribute_name)
def instantiate_from_class_path(class_path, *args, **kwargs):
    """Import the class named by the dotted ``class_path`` and construct it.

    Args:
        class_path: dotted path to a class, resolved with ``class_from_class_path``.
        *args: positional arguments forwarded to the class constructor.
        **kwargs: keyword arguments forwarded to the class constructor.

    Returns:
        The newly constructed instance.
    """
    return class_from_class_path(class_path)(*args, **kwargs)
def weights_init(m):
    """Initialize the parameters of a single module in place (meant to be used with ``Module.apply``).

    * ``torch.nn.BatchNorm1d``: weight ~ N(1.0, 0.02), bias = 0. Parameters that are ``None``
      (``affine=False``) are skipped.
    * ``torch.nn.Linear`` and its subclasses (including ``LinearWithStructuredBias``):
      Xavier-normal weight and zero bias (when a bias exists).

    Any other module type is left untouched.

    Args:
        m: the module to initialize.
    """
    if isinstance(m, torch.nn.BatchNorm1d):
        # affine=False batch norms have weight/bias set to None; there is nothing to initialize
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, torch.nn.Linear):
        # LinearWithStructuredBias derives from torch.nn.Linear, so it is covered by this branch
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
class LinearWithStructuredBias(torch.nn.Linear):
"""A `torch.nn.Linear` layer where batch indices are given as input to the forward pass.
Args:
in_features: passed to `torch.nn.Linear`
out_features: passed to `torch.nn.Linear`
n_batch: the dimensionality of the batch representation
categorical_covariate_dimensions: a list of integers containing the number of categories
for each categorical covariate
label_to_bias_hidden_layers: a list of hidden layer sizes for the label-to-bias decoder
bias: passed to `torch.nn.Linear` (True is like the scvi-tools implementation)
label_to_bias_dressing_init_kwargs: a dictionary of keyword arguments to pass to
the `DressedLayer` constructor
"""
def __init__(
self,
in_features: int,
out_features: int,
label_to_bias_hidden_layers: list[int],
n_batch: int = 0,
categorical_covariate_dimensions: list[int] | None = None,
bias: bool = True,
label_to_bias_dressing_init_kwargs: dict[str, Any] | None = None,
):
super().__init__(in_features, out_features, bias=bias)
if categorical_covariate_dimensions is None:
categorical_covariate_dimensions = []
if label_to_bias_dressing_init_kwargs is None:
label_to_bias_dressing_init_kwargs = {}
if n_batch + sum(categorical_covariate_dimensions) < 1:
raise ValueError("in_features=0: at least one batch or categorical covariate dimension must be provided.")
self.bias_decoder = FullyConnectedLinear(
in_features=n_batch + sum(categorical_covariate_dimensions),
out_features=out_features,
n_hidden=label_to_bias_hidden_layers,
dressing_init_kwargs=label_to_bias_dressing_init_kwargs,
)
@abstractmethod
def compute_bias(
self,
batch_nb: torch.Tensor,
categorical_covariate_np: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Returns the bias given batch representations.
Args:
batch_nb: a tensor of batch representations (could be one-hot) of shape (n, batch_latent_dim)
categorical_covariate_np: a tensor of categorical covariates of shape (n, sum(n_categories_per_covariate))
Returns:
a tensor of shape (n, out_features)
"""
pass
def forward( # type: ignore[override]
self,
x_ng: torch.Tensor,
batch_nb: torch.Tensor,
categorical_covariate_np: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Computes the forward pass of the layer as
out = x @ self.weight.T + self.bias + bias
where bias is computed as
bias = bias_encoder(batch)
Args:
x_ng: a tensor of shape (n, in_features)
batch_nb: a tensor of batch indices of shape (n, batch_latent_dim)
categorical_covariate_np: a tensor of categorical covariates of shape (n, sum(n_categories_per_covariate))
"""
return super().forward(x_ng) + self.compute_bias(
batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np
)
class LinearWithBatchAndCovariates(LinearWithStructuredBias):
    """Structured-bias linear layer whose extra bias is decoded from the batch representation
    concatenated (along the last dimension) with the categorical covariates."""

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # both inputs are required by this layer
        if categorical_covariate_np is None:
            raise ValueError("Categorical covariates must be provided to LinearWithBatchAndCovariates")
        joint_input = torch.cat([batch_nb, categorical_covariate_np], dim=-1)
        return self.bias_decoder(joint_input)
class LinearWithBatch(LinearWithStructuredBias):
    """Structured-bias linear layer whose extra bias depends on the batch representation only.

    Any categorical covariates passed in are ignored.
    """

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        structured_bias = self.bias_decoder(batch_nb)
        return structured_bias
class LinearWithCovariates(LinearWithStructuredBias):
    """Structured-bias linear layer whose extra bias depends on the categorical covariates only.

    The batch representation passed in is ignored.
    """

    def compute_bias(
        self,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if categorical_covariate_np is not None:
            return self.bias_decoder(categorical_covariate_np)
        raise ValueError("Categorical covariates must be provided to LinearWithCovariates")
class FullyConnectedWithBatchArchitecture(torch.nn.Module):
"""
Fully connected block of layers (can be empty) that can include LinearWithBatch layers.
The forward pass takes per-cell batches.
Args:
in_features: The dimensionality of the input
layers: A list of dictionaries, each containing the following keys:
* ``class_path``: the class path of the layer to use
* ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
- must contain "out_features"
"""
def __init__(
self,
in_features: int,
layers: list[dict],
):
super().__init__()
for layer in layers:
assert "out_features" in layer["init_args"], """
"out_features" must be specified in init_args for hidden layers, e.g.
- class_path: cellarium.ml.models.scvi.LinearWithBatch
init_args:
out_features: 128
"""
if len(layers) == 0:
module_list = torch.nn.ModuleList([torch.nn.Identity()])
out_features = in_features
else:
module_list = torch.nn.ModuleList()
n_hidden = [layer["init_args"].get("out_features") for layer in layers]
for layer, n_in, n_out in zip(layers, [in_features] + n_hidden, n_hidden):
layer["init_args"]["out_features"] = n_out
module_list.append(
DressedLayer(
instantiate_from_class_path(
layer["class_path"],
in_features=n_in,
bias=True,
**layer["init_args"],
),
**layer["dressing_init_args"],
)
)
assert hasattr(module_list[-1].layer, "out_features") and isinstance(
module_list[-1].layer.out_features, int
)
out_features = module_list[-1].layer.out_features
self.module_list = module_list
self.out_features = out_features
def forward(
self,
x_ng: torch.Tensor,
batch_nb: torch.Tensor,
categorical_covariate_np: torch.Tensor | None,
) -> torch.Tensor:
x_ = x_ng
for dressed_layer in self.module_list:
x_ = (
dressed_layer(x_, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
if (hasattr(dressed_layer, "layer") and isinstance(dressed_layer.layer, LinearWithStructuredBias))
else dressed_layer(x_)
)
return x_
class EncoderSCVI(torch.nn.Module):
    """
    Encode data of ``in_features`` dimensions into a latent space of ``out_features`` dimensions.

    The input passes through a (possibly empty) stack of hidden layers. Two parallel final layers then
    produce the mean and the log-variance of a diagonal Normal distribution over the latent space.

    Args:
        in_features: The dimensionality of the input (data space)
        out_features: The dimensionality of the output (latent space)
        hidden_layers: A list of dictionaries, each containing the following keys:
            * ``class_path``: the class path of the layer to use
            * ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
              - must contain "out_features"
        final_layer: Same as hidden_layers, but for the final (mean and variance) layers. ``out_features``
            is set by this class; ``init_args["bias"]`` (default True) is honored. If ``class_path`` is a
            ``LinearWithStructuredBias`` subclass, batch / covariate information is injected into both
            final layers.
        var_eps: Minimum value for the variance; used for numerical stability
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_layers: list[dict],
        final_layer: dict,
        var_eps: float = 1e-4,
    ):
        super().__init__()
        self.fully_connected = FullyConnectedWithBatchArchitecture(in_features, hidden_layers)
        final_init_args = final_layer["init_args"]
        # "bias" is passed explicitly, so strip it from the forwarded kwargs; passing it twice raised
        # "got multiple values for keyword argument 'bias'" (DecoderSCVI already does the same)
        final_bias = final_init_args.get("bias", True)
        final_kwargs = {k: v for k, v in final_init_args.items() if k != "bias"}
        self.mean_encoder = instantiate_from_class_path(
            final_layer["class_path"],
            in_features=self.fully_connected.out_features,
            out_features=out_features,
            bias=final_bias,
            **final_kwargs,
        )
        self.var_encoder = instantiate_from_class_path(
            final_layer["class_path"],
            in_features=self.fully_connected.out_features,
            out_features=out_features,
            bias=final_bias,
            **final_kwargs,
        )
        # both final layers come from the same class_path, so this flag describes both
        self.mean_encoder_injects_covariates = isinstance(self.mean_encoder, LinearWithStructuredBias)
        self.var_eps = var_eps

    def forward(
        self,
        x_ng: torch.Tensor,
        batch_nb: torch.Tensor,
        categorical_covariate_np: torch.Tensor | None,
    ) -> Distribution:
        """Return the approximate posterior ``Normal(mean, sqrt(exp(log_var) + var_eps))`` for ``x_ng``."""
        q_nh = self.fully_connected(x_ng, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
        if self.mean_encoder_injects_covariates:
            q_mean_nk = self.mean_encoder(q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
            q_log_var_nk = self.var_encoder(
                q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np
            )
        else:
            q_mean_nk = self.mean_encoder(q_nh)
            q_log_var_nk = self.var_encoder(q_nh)
        q_var_nk = torch.exp(q_log_var_nk) + self.var_eps
        return Normal(q_mean_nk, q_var_nk.sqrt())
class DecoderSCVI(torch.nn.Module):
"""
Decode data of ``in_features`` latent dimensions into data space of ``out_features`` dimensions.
Args:
in_features: The dimensionality of the input (latent space)
out_features: The dimensionality of the output (data space)
hidden_layers: A list of dictionaries, each containing the following keys:
* ``class_path``: the class path of the layer to use
* ``init_args``: a dictionary of keyword arguments to pass to the layer's constructor
- must contain "out_features"
final_layer: Same as hidden_layers, but for the final layer
dispersion: Granularity at which the overdispersion of the negative binomial distribution is computed
gene_likelihood: Distribution to use for reconstruction in the generative process
scale_activation: Activation layer to use to compute normalized counts (before multiplying by library size)
final_additive_bias: If True, the final layer will have a batch-specific bias added after the activation.
If final_layer is a LinearWithBatch layer and final_additive_bias is True, the last layer of the decoder
will act as a batch-specific affine transformation.
eps: Numerical stability factor added to mean and inverse overdispersion of negative binomial
"""
def __init__(
self,
in_features: int,
out_features: int,
hidden_layers: list[dict],
final_layer: dict,
n_batch: int,
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
scale_activation: Literal["softmax", "softplus"] = "softmax",
final_additive_bias: bool = False,
n_cats_per_cov: list[int] | None = None,
eps: float = 1e-10,
):
super().__init__()
self.eps = eps
self.n_batch = n_batch
self.n_cats_per_cov = n_cats_per_cov
if gene_likelihood == "zinb":
raise NotImplementedError("Zero-inflated negative binomial not yet implemented")
self.gene_likelihood = gene_likelihood
self.fully_connected = FullyConnectedWithBatchArchitecture(in_features, hidden_layers)
self.inverse_overdispersion_decoder = (
torch.nn.Linear(self.fully_connected.out_features, out_features)
if ((gene_likelihood != "poisson") and (dispersion == "gene-cell"))
else None
)
self.dropout_decoder = (
torch.nn.Linear(self.fully_connected.out_features, out_features) if (gene_likelihood == "zinb") else None
)
final_layer_init_args = final_layer["init_args"]
self.normalized_count_decoder = instantiate_from_class_path(
final_layer["class_path"],
in_features=self.fully_connected.out_features,
out_features=out_features,
bias=final_layer_init_args.get("bias", True),
**{k: v for k, v in final_layer_init_args.items() if k != "bias"},
)
self.count_decoder_takes_batch = isinstance(self.normalized_count_decoder, LinearWithStructuredBias)
self.normalized_count_activation = (
torch.nn.Softmax(dim=-1) if (scale_activation == "softmax") else torch.nn.Softplus()
)
self.final_additive_bias = final_additive_bias
if self.n_cats_per_cov is None:
categorical_features = 0
else:
categorical_features = sum(self.n_cats_per_cov)
self.final_additive_bias_layer: torch.nn.Sequential | None = None
if self.final_additive_bias:
self.final_additive_bias_layer = torch.nn.Sequential(
FullyConnectedLinear(
in_features=self.n_batch + categorical_features,
out_features=out_features,
n_hidden=[],
dressing_init_kwargs={},
),
torch.nn.ReLU(),
)
def forward(
self,
z_nk: torch.Tensor,
batch_nb: torch.Tensor,
inverse_overdispersion: torch.Tensor | None,
library_size_n1: torch.Tensor,
categorical_covariate_np: torch.Tensor | None = None,
) -> Distribution:
# bulk of the network
q_nh = self.fully_connected(
z_nk,
batch_nb=batch_nb,
categorical_covariate_np=categorical_covariate_np,
)
# mean counts
unnormalized_chi_ng = (
self.normalized_count_decoder(q_nh, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
if self.count_decoder_takes_batch
else self.normalized_count_decoder(q_nh)
)
chi_ng = self.normalized_count_activation(unnormalized_chi_ng)
if self.final_additive_bias_layer is not None:
count_mean_ng = torch.exp(library_size_n1) * chi_ng + self.final_additive_bias_layer(
torch.cat([batch_nb, categorical_covariate_np], dim=-1)
if categorical_covariate_np is not None
else batch_nb
)
else:
count_mean_ng = torch.exp(library_size_n1) * chi_ng
# construct the count distribution
dist: Distribution
match self.gene_likelihood:
case "nb":
if inverse_overdispersion is None:
assert self.inverse_overdispersion_decoder is not None, (
"inverse_overdispersion must be provided when not using Poisson or gene-cell dispersion"
)
inverse_overdispersion = self.inverse_overdispersion_decoder(q_nh).exp()
dist = NegativeBinomial(count_mean_ng + self.eps, inverse_overdispersion + self.eps)
case "poisson":
dist = Poisson(count_mean_ng + self.eps)
case "zinb":
raise NotImplementedError("ZINB is not currently implemented")
# dist = ZeroInflatedNegativeBinomial(count_mean_ng, inverse_overdispersion, self.dropout_decoder(q_nh))
return dist
def compute_annealed_kl_weight(
epoch: int,
step: int,
n_epochs_kl_warmup: int | None,
n_steps_kl_warmup: int | None,
max_kl_weight: float = 1.0,
min_kl_weight: float = 0.0,
) -> float:
"""Computes the kl weight for the current step or epoch.
If both `n_epochs_kl_warmup` and `n_steps_kl_warmup` are None `max_kl_weight` is returned.
Args:
epoch: Current epoch.
step: Current step.
n_epochs_kl_warmup: Number of epochs to warm up the KL weight.
n_steps_kl_warmup: Number of steps to warm up the KL weight.
max_kl_weight: Maximum KL weight.
min_kl_weight: Minimum KL weight.
Returns:
The KL weight for the current step or epoch.
"""
if min_kl_weight > max_kl_weight:
raise ValueError(f"min_kl_weight={min_kl_weight} is larger than max_kl_weight={max_kl_weight}.")
slope = max_kl_weight - min_kl_weight
if n_epochs_kl_warmup:
if epoch < n_epochs_kl_warmup:
return slope * (epoch / n_epochs_kl_warmup) + min_kl_weight
elif n_steps_kl_warmup:
if step < n_steps_kl_warmup:
return slope * (step / n_steps_kl_warmup) + min_kl_weight
return max_kl_weight
[docs]
class SingleCellVariationalInference(CellariumModel, PredictMixin, ValidateMixin):
"""
Flexible version of single-cell variational inference (scVI) [1] re-implemented in Cellarium ML.
**References:**
1. `Deep generative modeling for single-cell transcriptomics (Lopez et al.)
<https://www.nature.com/articles/s41592-018-0229-2>`_.
Args:
var_names_g: The variable names schema for the input data validation.
encoder: Dict specifying the encoder configuration.
decoder: Dict specifying the decoder configuration.
n_latent: Dimension of the latent space.
n_batch: Number of total batches in the dataset.
batch_representation_sampled: True to sample latent batch from a distribution.
n_continuous_cov: Number of continuous covariates.
n_cats_per_cov: A list of integers containing the number of categories for each categorical covariate.
dropout_rate: Dropout rate for hidden units in the encoder only.
input_gene_dropout_rate: Gene dropout rate for input data that goes into the encoder during training.
dispersion: Flexibility of the dispersion parameter when ``gene_likelihood`` is either ``"nb"`` or
``"zinb"``. One of the following:
* ``"gene"``: parameter is constant per gene across cells.
* ``"gene-batch"``: parameter is constant per gene per batch.
* ``"gene-label"``: parameter is constant per gene per label.
* ``"gene-cell"``: parameter is constant per gene per cell.
log_variational: If ``True``, use :func:`~torch.log1p` on input data before encoding for numerical stability
(not normalization).
gene_likelihood: Distribution to use for reconstruction in the generative process. One of the following:
* ``"nb"``: :class:`~scvi.distributions.NegativeBinomial`.
* ``"zinb"``: :class:`~scvi.distributions.ZeroInflatedNegativeBinomial`. (not implemented)
* ``"poisson"``: :class:`~scvi.distributions.Poisson`.
latent_distribution: Distribution to use for the latent space. One of the following:
* ``"normal"``: isotropic normal.
* ``"ln"``: logistic normal with normal params N(0, 1). (not implemented)
use_batch_norm: Specifies where to use :class:`~torch.nn.BatchNorm1d` in the model. One of the following:
* ``"none"``: don't use batch norm in either encoder(s) or decoder.
* ``"encoder"``: use batch norm only in the encoder(s).
* ``"decoder"``: use batch norm only in the decoder.
* ``"both"``: use batch norm in both encoder(s) and decoder.
use_layer_norm: Specifies where to use :class:`~torch.nn.LayerNorm` in the model. One of the following:
* ``"none"``: don't use layer norm in either encoder(s) or decoder.
* ``"encoder"``: use layer norm only in the encoder(s).
* ``"decoder"``: use layer norm only in the decoder.
* ``"both"``: use layer norm in both encoder(s) and decoder.
Note: only one of use_batch_norm or use_layer_norm should be specified.
use_size_factor_key: If ``True``, use the :attr:`~anndata.AnnData.obs` column as defined by the
``size_factor_key`` parameter in the model's ``setup_anndata`` method as the scaling
factor in the mean of the conditional distribution. If ``False``, the observed library
size (log of the sum of counts per cell) is used. Should be ``False``.
reconstruct_counts_on_predict: Changes the behavior of :meth:`predict`. True will reconstruct gene
expression count data, False will return the latent representations
reconstruction_var_names_g: List of var_names to be reconstructed (outputs are dense matrices)
reconstruction_transform_batch: None will reconstruct in the original data batch. This is like
imputation or smoothing. An integer will reconstruct counts in the specified batch index.
The string "mean" will reconstruct counts in the first 10 batches and return the mean.
reconstruction_n_latent_samples: Number of latent samples to use for reconstruction. Each latent sample
will be used to compute the mean of the generative distribution, and the final output will be the
mean of those.
reconstruction_use_latent_mean: True to use the mean of the latent distribution rather than sampling.
reconstruction_use_importance_sampling: True to use importance sampling weighted by each latent sample's
likelihood.
reconstructed_library_size: The library size to use for the reconstruction, common to all cells.
use_flow: If True, use a Neural Spline Flow (NSF) as the prior on the latent space instead of the
standard normal N(0, I). The flow is unconditional (batch-blind) and is jointly trained with the
encoder/decoder via an MC-KL estimate: E_q[log q(z|x) - log p_flow(z)].
flow_hidden_features: Hidden layer widths for the NSF. Only used when ``use_flow=True``.
cell_type_categories: Ordered list of CL ID strings (e.g. ``["CL:0000540", ...]``) that
matches ``adata.obs[cell_type_col].cat.categories`` exactly (same order), so that
integer codes from ``.cat.codes`` map directly to rows of the internal distance
buffer. Required if ``ontology_distance_matrix`` is provided.
ontology_distance_matrix: Square :class:`pandas.DataFrame` with CL ID strings as both
index and columns, as returned by
:func:`~cellarium.ml.utilities.data.compute_cl_distance_matrix`. The constructor
will slice and reorder this to match ``cell_type_categories``. Enables the
frequency-weighted Spearman correlation metric (``val_ontology_spearman``) during
validation. Not saved to checkpoints.
val_cell_type_classifier_reservoir_size: Maximum number of cells to retain per split
(train / test) for the logistic regression cell type classifier. Reservoir sampling
is used so this bound is respected regardless of validation set size. Train cells
are drawn from even-numbered validation batches; test cells from odd-numbered
batches. Ignored when ``cell_type_categories`` is not provided. Default 50_000.
"""
def __init__(
    self,
    var_names_g: Sequence[str],
    encoder: dict[str, list[dict] | dict | bool],
    decoder: dict[str, list[dict] | dict | bool],
    n_batch: int = 0,
    n_latent: int = 10,
    n_continuous_cov: int = 0,
    n_cats_per_cov: list[int] | None = None,
    dropout_rate: float = 0.1,
    dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
    log_variational: bool = True,
    gene_likelihood: Literal["zinb", "nb", "poisson"] = "nb",
    latent_distribution: Literal["normal", "ln"] = "normal",
    batch_embedded: bool = False,
    batch_representation_sampled: bool = False,
    n_latent_batch: int | None = None,
    z_kl_weight_max: float = 1.0,
    batch_kl_weight_max: float = 0.0,
    input_gene_dropout_rate: float = 0.0,
    use_batch_norm: Literal["encoder", "decoder", "none", "both"] = "both",
    use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
    kl_warmup_epochs: int | None = 400,
    kl_warmup_steps: int | None = None,
    kl_annealing_start: float = 0.0,
    use_size_factor_key: bool = False,
    reconstruct_counts_on_predict: bool = False,
    reconstruction_var_names_g: np.ndarray | list | None = None,
    reconstruction_transform_batch: None | int | str = 0,
    reconstruction_n_latent_samples: int = 30,
    reconstruction_use_latent_mean: bool = False,
    reconstruction_use_importance_sampling: bool = False,
    reconstructed_library_size: int = 10_000,
    reconstruction_transform_categorical_covariates: list[int] | None = None,
    use_flow: bool = False,
    flow_hidden_features: list[int] = [64, 64],
    cell_type_categories: list[str] | None = None,
    ontology_distance_matrix: pd.DataFrame | None = None,
    val_cell_type_classifier_reservoir_size: int = 50_000,
):
    """Build the encoder, decoder, optional flow prior and bookkeeping state.

    See the class docstring for the meaning of each argument. Note that the ``hidden_layers`` and
    ``final_layer`` entries of ``encoder`` / ``decoder`` are filled in place with batch/covariate
    sizes and dressing defaults before the submodules are built.
    """
    super().__init__()
    self.var_names_g = np.array(var_names_g)
    self.n_input = len(self.var_names_g)
    self.dispersion = dispersion
    self.n_latent = n_latent
    self.n_batch = n_batch
    self.log_variational = log_variational
    self.gene_likelihood = gene_likelihood
    self.input_gene_dropout_rate = input_gene_dropout_rate
    self.latent_distribution = latent_distribution
    self.n_cats_per_cov = n_cats_per_cov if n_cats_per_cov is not None else []
    self.use_size_factor_key = use_size_factor_key
    self.batch_embedded = batch_embedded
    self.batch_representation_sampled = batch_representation_sampled
    self.n_latent_batch = n_latent_batch

    # KL annealing / weighting
    if not (0.0 <= kl_annealing_start <= 1.0):
        raise ValueError(f"kl_annealing_start={kl_annealing_start} must be in the range [0.0, 1.0].")
    self.kl_annealing_start = kl_annealing_start
    assert not ((kl_warmup_steps is not None) and (kl_warmup_epochs is not None)), (
        "Only one of kl_warmup_epochs or kl_warmup_steps can be specified, not both."
    )
    self.kl_warmup_epochs = kl_warmup_epochs
    self.kl_warmup_steps = kl_warmup_steps
    assert batch_kl_weight_max >= 0.0, "batch_kl_weight must be non-negative"
    self.batch_kl_weight_max = batch_kl_weight_max
    assert z_kl_weight_max >= 0.0, "z_kl_weight must be non-negative"
    self.z_kl_weight_max = z_kl_weight_max
    self.epoch = 0
    self.step = 0

    # reconstruction settings: keep only requested var_names that exist in the model
    self.reconstruction_var_names_g = reconstruction_var_names_g
    if reconstruction_var_names_g is not None:
        allowed_var_names_g = []
        model_var_names = set(self.var_names_g)
        for var_name in reconstruction_var_names_g:
            if var_name not in model_var_names:
                import warnings

                warnings.warn(
                    f"reconstruction_var_names_g {var_name} not found in model var_names_g. "
                    "It will not be included in the reconstructed output. Check "
                    "model.reconstruction_var_names_g for the included var_names.",
                    UserWarning,
                )
            else:
                allowed_var_names_g.append(var_name)
        self.reconstruction_var_names_g = allowed_var_names_g
    else:
        if reconstruct_counts_on_predict:
            import warnings

            warnings.warn(
                "reconstruction_var_names_g not specified, so all var_names_g will be reconstructed on predict.",
                UserWarning,
            )
        self.reconstruction_var_names_g = self.var_names_g
    self.reconstruct_counts_on_predict = reconstruct_counts_on_predict
    self.reconstruction_transform_batch = reconstruction_transform_batch
    self.reconstruction_n_latent_samples = reconstruction_n_latent_samples
    self.reconstruction_use_latent_mean = reconstruction_use_latent_mean
    self.reconstruction_use_importance_sampling = reconstruction_use_importance_sampling
    self.reconstructed_library_size = reconstructed_library_size
    self.reconstruction_transform_categorical_covariates = reconstruction_transform_categorical_covariates
    self.use_flow = use_flow
    # copy: the signature default is a single shared list object, which must not be aliased by every instance
    self.flow_hidden_features = list(flow_hidden_features)

    # optional validation data metrics setup
    # ontology_distance_matrix without cell_type_categories: matrix is silently ignored
    # (can happen when the CLI links cell_type_categories from data but the batch key is absent)
    if cell_type_categories is None and ontology_distance_matrix is not None:
        logger.warning(
            "ontology_distance_matrix was provided but cell_type_categories is None; "
            "the matrix will be ignored and ontology-based metrics will be disabled."
        )
        ontology_distance_matrix = None
    self.cell_type_categories = list(cell_type_categories) if cell_type_categories is not None else None
    self.num_classes = len(cell_type_categories) if cell_type_categories is not None else None
    self.val_cell_type_classifier_reservoir_size = val_cell_type_classifier_reservoir_size
    if cell_type_categories is not None and ontology_distance_matrix is not None:
        missing = [c for c in cell_type_categories if c not in ontology_distance_matrix.index]
        if missing:
            raise ValueError(
                f"The following cell type categories are absent from ontology_distance_matrix: {missing}"
            )
        # reorder rows/columns to match the integer codes of cell_type_categories
        sub = ontology_distance_matrix.loc[cell_type_categories, cell_type_categories]
        self._ontology_matrix_numpy: np.ndarray | None = sub.to_numpy(dtype=np.float32)
        self.register_buffer(
            "ontology_distance_matrix", torch.tensor(self._ontology_matrix_numpy), persistent=False
        )
    else:
        self._ontology_matrix_numpy = None
        self.register_buffer("ontology_distance_matrix", None, persistent=False)

    if n_continuous_cov > 0:
        raise NotImplementedError("Continuous covariates are not yet implemented")
    if gene_likelihood == "zinb":
        raise NotImplementedError("Zero-inflated negative binomial not yet implemented")
    if latent_distribution == "ln":
        raise NotImplementedError("Logistic normal latent distribution is not yet implemented")

    # if you use one-hot and try to specify a different latent batch than n_batch, raise an error
    if (not self.batch_embedded) and (self.n_latent_batch is not None) and (self.n_latent_batch != self.n_batch):
        raise ValueError("n_latent_batch must be equal to n_batch if batch_embedded is False")
    # if n_latent_batch not specified, set it to n_batch
    if self.n_latent_batch is None:
        self.n_latent_batch = self.n_batch  # same dim as one-hot would be

    # handle the embedded batch posterior
    # initialize the means as one-hot, std as 1 (after exp)
    if not self.batch_embedded:
        self.batch_representation_mean_bd: torch.nn.Parameter | None = None
        self.batch_representation_std_unconstrained_bd: torch.nn.Parameter | None = None
    else:
        self.batch_representation_mean_bd = torch.nn.Parameter(torch.eye(self.n_batch, self.n_latent_batch))
        self.batch_representation_std_unconstrained_bd = torch.nn.Parameter(
            torch.zeros(self.n_batch, self.n_latent_batch)
        )

    if self.dispersion == "gene":
        self.px_r = torch.nn.Parameter(torch.randn(self.n_input))
    elif self.dispersion == "gene-label":
        raise NotImplementedError
        # self.px_r = torch.nn.Parameter(torch.randn(self.n_input, n_labels))
    elif self.dispersion == "gene-cell":
        self.px_r = torch.nn.Parameter(torch.zeros(1))  # dummy
    else:
        raise ValueError(
            "dispersion must be one of ['gene', 'gene-label', 'gene-cell'], but input was {}".format(
                self.dispersion
            )
        )

    use_batch_norm_encoder = use_batch_norm == "encoder" or use_batch_norm == "both"
    use_batch_norm_decoder = use_batch_norm == "decoder" or use_batch_norm == "both"
    use_layer_norm_encoder = use_layer_norm == "encoder" or use_layer_norm == "both"
    use_layer_norm_decoder = use_layer_norm == "decoder" or use_layer_norm == "both"

    def _fill_in_layer_input_args(layer: dict):
        # size the structured-bias inputs of a layer config according to its class_path
        if layer["class_path"] == "cellarium.ml.models.scvi.LinearWithBatch":
            layer["init_args"]["n_batch"] = self.n_latent_batch
            layer["init_args"]["categorical_covariate_dimensions"] = []
        elif layer["class_path"] == "cellarium.ml.models.scvi.LinearWithCovariates":
            layer["init_args"]["n_batch"] = 0
            layer["init_args"]["categorical_covariate_dimensions"] = self.n_cats_per_cov
        elif layer["class_path"] == "cellarium.ml.models.scvi.LinearWithBatchAndCovariates":
            layer["init_args"]["n_batch"] = self.n_latent_batch
            layer["init_args"]["categorical_covariate_dimensions"] = self.n_cats_per_cov

    def _fill_in_layer_defaults(layer: dict, batch_norm: bool, layer_norm: bool, dropout_p: float):
        # fill model-wide dressing defaults for any setting not given per layer
        if "dressing_init_args" not in layer:
            layer["dressing_init_args"] = {}
        if "use_batch_norm" not in layer["dressing_init_args"]:
            logger.info(f"use_batch_norm not specified individually in hidden layer, setting to {batch_norm}")
            layer["dressing_init_args"]["use_batch_norm"] = batch_norm
        if "use_layer_norm" not in layer["dressing_init_args"]:
            logger.info(f"use_layer_norm not specified individually in hidden layer, setting to {layer_norm}")
            layer["dressing_init_args"]["use_layer_norm"] = layer_norm
        if "dropout_rate" not in layer["dressing_init_args"]:
            logger.info(f"dropout_rate not specified individually in hidden layer, setting to {dropout_p}")
            layer["dressing_init_args"]["dropout_rate"] = dropout_p

    # encoder layers
    assert isinstance(encoder["hidden_layers"], list), "encoder hidden_layers must be a list"
    for layer in encoder["hidden_layers"]:
        _fill_in_layer_input_args(layer)
        _fill_in_layer_defaults(
            layer,
            batch_norm=use_batch_norm_encoder,
            layer_norm=use_layer_norm_encoder,
            dropout_p=dropout_rate,
        )
    assert isinstance(encoder["final_layer"], dict)
    _fill_in_layer_input_args(encoder["final_layer"])

    # decoder layers
    assert isinstance(decoder["hidden_layers"], list), "decoder hidden_layers must be a list"
    for layer in decoder["hidden_layers"]:
        _fill_in_layer_input_args(layer)
        if "dressing_init_args" in layer and "dropout_rate" in layer["dressing_init_args"]:
            if layer["dressing_init_args"]["dropout_rate"] != 0.0:
                logger.warning(
                    "Dropout is not supported in the decoder of scVI. "
                    "dropout_rate is being set to 0.0 in all decoder hidden layers."
                )
            layer["dressing_init_args"]["dropout_rate"] = 0.0  # scvi-tools does not use dropout in the decoder
        _fill_in_layer_defaults(
            layer,
            batch_norm=use_batch_norm_decoder,
            layer_norm=use_layer_norm_decoder,
            dropout_p=0,
        )
    assert isinstance(decoder["final_layer"], dict)
    _fill_in_layer_input_args(decoder["final_layer"])

    self.z_encoder = EncoderSCVI(
        in_features=self.n_input,
        out_features=self.n_latent,
        hidden_layers=encoder["hidden_layers"],
        final_layer=encoder["final_layer"],
    )
    assert isinstance(decoder["final_additive_bias"], bool)
    self.decoder = DecoderSCVI(
        in_features=self.n_latent,
        out_features=self.n_input,
        hidden_layers=decoder["hidden_layers"],
        final_layer=decoder["final_layer"],
        dispersion=self.dispersion,
        gene_likelihood=self.gene_likelihood,
        scale_activation="softplus" if use_size_factor_key else "softmax",
        final_additive_bias=decoder["final_additive_bias"],
        n_batch=self.n_latent_batch,  # for the (optional) sizing of the final additive bias layer
        n_cats_per_cov=self.n_cats_per_cov,  # for the (optional) sizing of the final additive bias layer
    )

    if use_flow:
        # NSF.__init__ calls torch.unique(..., dim=0) which is not supported on Meta tensors.
        # Lightning/jsonargparse instantiates models under a torch.device("meta") context to
        # avoid allocating real memory during CLI parsing. Wrapping with torch.device("cpu")
        # forces all intermediate tensors in the NSF constructor to be concrete CPU tensors.
        # The Trainer will move the module to the correct device via .to() before training.
        with torch.device("cpu"):
            self.flow: zuko.flows.NSF | None = zuko.flows.NSF(
                features=self.n_latent, context=0, hidden_features=self.flow_hidden_features
            )
    else:
        self.flow = None

    # detect whether the decoder injects categorical covariates at any point
    # (sum() of an empty list is 0, so no separate emptiness check is needed)
    self._decoder_uses_categorical_covariates: bool = any(
        isinstance(m, (LinearWithCovariates, LinearWithBatchAndCovariates)) for m in self.decoder.modules()
    ) or (self.decoder.final_additive_bias_layer is not None and sum(self.n_cats_per_cov) > 0)
    self.reset_parameters()
def reset_parameters(self) -> None:
    """Re-initialize learnable parameters and non-persistent buffers.

    - Applies ``weights_init`` to every submodule (it handles ``BatchNorm1d``, ``torch.nn.Linear`` and
      ``LinearWithStructuredBias`` layers).
    - Draws ``px_r`` from a standard normal.
    - When batch embeddings exist, sets their means to ``torch.eye(n_batch, n_latent_batch)`` and their
      unconstrained (log) standard deviations to 0.
    - When a flow prior exists, rebuilds it with fresh random weights on CPU and copies them onto the
      flow's current device.
    - When an ontology matrix was supplied, re-registers the ``ontology_distance_matrix`` buffer.
    """
    # ``Module.apply`` already recurses over every submodule (``self`` included), so one call initializes
    # each layer exactly once. Calling ``m.apply`` for every ``m`` in ``self.modules()`` re-initialized
    # deeply nested layers once per ancestor module.
    self.apply(weights_init)
    torch.nn.init.normal_(self.px_r, mean=0.0, std=1.0)
    if self.batch_representation_mean_bd is not None and self.batch_representation_std_unconstrained_bd is not None:
        assert isinstance(self.n_latent_batch, int)  # mypy
        with torch.no_grad():
            self.batch_representation_mean_bd.data.copy_(torch.eye(self.n_batch, self.n_latent_batch))
            # unconstrained value 0 -> std exp(0) (+1e-5 floor in batch_embedding_distribution)
            self.batch_representation_std_unconstrained_bd.data.fill_(0.0)
    if self.flow is not None:
        # Reinitialize the NSF with fresh random weights on CPU, then copy to the target device.
        # This is necessary because configure_model() moves the whole model with to_empty() when
        # any parameter is on meta device (encoder/decoder params), which leaves NSF parameters
        # as uninitialized memory even though the NSF was constructed on CPU. weights_init()
        # doesn't cover NSF's spline parameters, so they would stay uninitialized without this.
        target_device = next(self.flow.parameters()).device
        with torch.device("cpu"):
            fresh_flow = zuko.flows.NSF(
                features=self.n_latent, context=0, hidden_features=self.flow_hidden_features
            )
        self.flow.load_state_dict({k: v.to(target_device) for k, v in fresh_flow.state_dict().items()})
    if self._ontology_matrix_numpy is not None:
        # Re-register after meta-device materialization: persistent=False buffers are not
        # restored by the state dict, and to_empty() leaves them as uninitialised storage.
        self.register_buffer(
            "ontology_distance_matrix",
            torch.tensor(self._ontology_matrix_numpy, device=self.px_r.device),
            persistent=False,
        )
def batch_embedding_distribution(self, batch_index_n: torch.Tensor) -> Distribution:
    """Return the per-cell Normal distribution over batch embeddings.

    Each cell's location is the row of ``batch_representation_mean_bd`` picked by its batch index.
    Its scale is ``exp`` of the matching row of ``batch_representation_std_unconstrained_bd`` plus a
    ``1e-5`` floor, which keeps the scale strictly positive.
    """
    mean_table = self.batch_representation_mean_bd
    log_std_table = self.batch_representation_std_unconstrained_bd
    assert mean_table is not None
    assert log_std_table is not None
    rows = batch_index_n.long()
    loc = mean_table[rows, :]
    scale = log_std_table[rows, :].exp() + 1e-5
    return Normal(loc, scale)
[docs]
def batch_representation_from_batch_index(
    self,
    batch_index_n: torch.Tensor,
    use_mean_though_sampling: bool = False,
) -> torch.Tensor:
    """Compute a batch representation from batch indices.

    If self.batch_embedded is False, the batch representation will be one-hot (like scvi-tools).
    If self.batch_embedded is True:
        If self.batch_representation_sampled is True, the batch representation is sampled from a normal
        distribution (or its mean is used when ``use_mean_though_sampling`` is True).
        If self.batch_representation_sampled is False, the batch representation is a point estimate
        (the row of ``batch_representation_mean_bd``).

    Args:
        batch_index_n: integer batch index per cell.
        use_mean_though_sampling: when the representation is sampled, return the distribution mean instead
            of a reparameterized sample.

    Returns:
        A (n, n_batch) one-hot matrix, or the per-cell embedding rows.
    """
    if not self.batch_embedded:
        # Flatten rather than squeeze: squeeze() turns a single-cell minibatch of shape (1,) or (1, 1) into
        # a 0-d tensor, for which one_hot returns a 1-D vector instead of a (1, n_batch) matrix.
        flat_index_n = batch_index_n.reshape(-1).long()
        return torch.nn.functional.one_hot(flat_index_n, num_classes=self.n_batch).float()
    if self.batch_representation_sampled:
        distribution = self.batch_embedding_distribution(batch_index_n=batch_index_n)
        return distribution.mean if use_mean_though_sampling else distribution.rsample()
    assert self.batch_representation_mean_bd is not None
    return self.batch_representation_mean_bd[batch_index_n.long(), :]
[docs]
def categorical_onehot_from_categorical_index(
self,
categorical_covariate_index_nd: torch.Tensor | None,
) -> torch.Tensor | None:
"""Compute one-hot encoding of categorical covariates from integer category indices.
Args:
categorical_covariate_index_nd: a tensor of shape (n, n_categorical_covariates)
"""
if categorical_covariate_index_nd is not None:
# make the categorical covariates one-hot
categorical_covariate_np = torch.cat(
[
torch.nn.functional.one_hot(categorical_covariate_index_nd[:, i].long(), num_classes=n_cats).float()
for i, n_cats in enumerate(self.n_cats_per_cov)
],
dim=1,
)
return categorical_covariate_np
return None
[docs]
def inference(
self,
x_ng: torch.Tensor,
batch_nb: torch.Tensor,
continuous_covariates_nc: torch.Tensor | None = None,
categorical_covariate_np: torch.Tensor | None = None,
):
"""
High level inference method.
Runs the inference (encoder) model.
"""
encoder_input_ng = x_ng
if self.log_variational:
encoder_input_ng = torch.log1p(encoder_input_ng)
qz = self.z_encoder(x_ng=encoder_input_ng, batch_nb=batch_nb, categorical_covariate_np=categorical_covariate_np)
z = qz.rsample()
if self.use_flow:
assert self.flow is not None # for mypy
# Build the flow prior distribution for this forward pass.
# self.flow() returns a NormalizingFlow with event_shape=(n_latent,) and no context.
pz = self.flow()
return dict(z=z, qz=qz, pz=pz)
return dict(z=z, qz=qz)
[docs]
def generative(
self,
z_nk: torch.Tensor,
library_size_n1: torch.Tensor,
batch_nb: torch.Tensor,
continuous_covariates_nc: torch.Tensor | None = None,
categorical_covariate_np: torch.Tensor | None = None,
) -> dict[str, Distribution]:
"""Runs the generative model."""
inverse_overdispersion: torch.Tensor | None
match self.dispersion:
case "gene":
inverse_overdispersion = self.px_r.exp()
case "gene-cell":
inverse_overdispersion = None
case "gene-batch":
inverse_overdispersion = torch.nn.functional.linear(
batch_nb,
self.px_r,
).exp()
case "gene-label":
inverse_overdispersion = None
raise NotImplementedError
count_distribution = self.decoder(
z_nk=z_nk,
batch_nb=batch_nb,
categorical_covariate_np=categorical_covariate_np,
inverse_overdispersion=inverse_overdispersion,
library_size_n1=library_size_n1,
)
# prior on latent z
pz = Normal(torch.zeros_like(z_nk), torch.ones_like(z_nk))
return dict(px=count_distribution, pz=pz)
[docs]
def forward(
    self,
    x_ng: torch.Tensor,
    var_names_g: np.ndarray,
    batch_index_n: torch.Tensor,
    continuous_covariates_nc: torch.Tensor | None = None,
    categorical_covariate_index_nd: torch.Tensor | None = None,
    total_mrna_umis_n: torch.Tensor | None = None,
):
    """
    Args:
        x_ng:
            Gene counts matrix.
        var_names_g:
            The list of the variable names in the input data.
        batch_index_n:
            Batch indices of input cells as integers.
        continuous_covariates_nc:
            Continuous covariates for each cell (c-dimensional).
        categorical_covariate_index_nd:
            Categorical covariates for each cell (d-dimensional). Integer membership categorical codes.
        total_mrna_umis_n:
            Total mRNA UMIs for each cell (not log scaled) if this should be used.

    Returns:
        A dictionary with keys:
        - "loss": The scalar loss: mean over cells of the reconstruction loss plus the annealed,
          weighted KL terms.
        - "reconstruction_loss": Per-cell negative log-likelihood of ``x_ng``.
        - "kl_divergence_z": Per-cell KL divergence for the latent variable z (a Monte Carlo estimate
          when a flow prior is used).
        - "kl_divergence_batch": Per-cell KL divergence of the sampled batch embedding from a standard
          normal, or the integer 0 when that term is disabled.
        - "z_nk": The sampled latent variable z.
    """
    assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
    assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

    batch_nb = self.batch_representation_from_batch_index(batch_index_n)
    categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd)

    # log library size: from the supplied totals when use_size_factor_key, otherwise from the observed counts
    if self.use_size_factor_key:
        assert total_mrna_umis_n is not None, "total_mrna_umis_n must be provided when use_size_factor_key=True"
        library_size_n1 = torch.log(total_mrna_umis_n).unsqueeze(-1)
    else:
        library_size_n1 = torch.log(x_ng.sum(dim=-1, keepdim=True))

    # Randomly drop out some genes in the encoder input, during training only. In eval mode (validation,
    # which calls this method through validate()) the encoder sees the full input, so validation metrics
    # are not perturbed by augmentation noise.
    if self.training and self.input_gene_dropout_rate > 0.0:
        with torch.no_grad():
            dropout_mask_ng = torch.rand_like(x_ng) > self.input_gene_dropout_rate
            inference_input_x_ng = x_ng * dropout_mask_ng
    else:
        inference_input_x_ng = x_ng

    inference_outputs = self.inference(
        x_ng=inference_input_x_ng,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )
    generative_outputs = self.generative(
        z_nk=inference_outputs["z"],
        library_size_n1=library_size_n1,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )

    kl_annealing_weight = compute_annealed_kl_weight(
        epoch=self.epoch,
        step=self.step,
        n_epochs_kl_warmup=self.kl_warmup_epochs,
        n_steps_kl_warmup=self.kl_warmup_steps,
        max_kl_weight=1.0,
        min_kl_weight=self.kl_annealing_start,
    )

    # KL divergence for z
    if self.use_flow:
        # MC estimate: KL(q(z|x) || p_flow(z)) = E_q[log q(z) - log p_flow(z)]
        # qz is Normal with per-dim log_probs; sum over latent dim to get per-cell scalar.
        # self.flow() (NormalizingFlow) already reduces the event dimension, giving shape (n,).
        z_nk = inference_outputs["z"]
        log_qz_n = inference_outputs["qz"].log_prob(z_nk).sum(dim=-1)
        # nan_to_num guards against spline instability producing NaN / ±inf in log_pz.
        log_pz_n = torch.nan_to_num(inference_outputs["pz"].log_prob(z_nk), nan=0.0, posinf=0.0, neginf=-1e4)
        kl_divergence_z_n = log_qz_n - log_pz_n
    else:
        kl_divergence_z_n = kl(inference_outputs["qz"], generative_outputs["pz"]).sum(dim=1)

    # optional KL divergence for batch representation
    kl_divergence_batch_n: torch.Tensor | int
    if self.batch_representation_sampled and (self.batch_kl_weight_max > 0):
        kl_divergence_batch_n = kl(
            self.batch_embedding_distribution(batch_index_n=batch_index_n),
            Normal(torch.zeros_like(batch_nb), torch.ones_like(batch_nb)),
        ).sum(dim=1)
    else:
        kl_divergence_batch_n = 0

    # reconstruction loss
    rec_loss_n = -generative_outputs["px"].log_prob(x_ng).sum(-1)

    # full loss
    assert kl_annealing_weight >= 0.0 and kl_annealing_weight <= 1.0, (
        f"Invalid KL annealing weight: {kl_annealing_weight}"
    )
    loss = torch.mean(
        rec_loss_n
        + kl_annealing_weight
        * (self.z_kl_weight_max * kl_divergence_z_n + self.batch_kl_weight_max * kl_divergence_batch_n),
        dim=0,
    )

    return {
        "loss": loss,
        "reconstruction_loss": rec_loss_n,
        "kl_divergence_z": kl_divergence_z_n,
        "kl_divergence_batch": kl_divergence_batch_n,
        "z_nk": inference_outputs["z"],
    }
def _latent_value_from_latent_distribution(self, d: Distribution) -> torch.Tensor:
"""Get the latent variable from the latent distribution."""
return d.mean
[docs]
def predict(
self,
x_ng: torch.Tensor,
var_names_g: np.ndarray,
batch_index_n: torch.Tensor,
continuous_covariates_nc: torch.Tensor | None = None,
categorical_covariate_index_nd: torch.Tensor | None = None,
):
"""
Args:
x_ng:
Gene counts matrix.
var_names_g:
The list of the variable names in the input data.
batch_index_n:
Batch indices of input cells as integers.
continuous_covariates_nc:
Continuous covariates for each cell (c-dimensional).
categorical_covariate_index_nd:
Categorical covariates for each cell (d-dimensional where d is number of categorical variables).
Values are integer membership categorical codes.
Returns:
A dictionary with the following keys:
- ``x_ng``:
- If :attr:`self.reconstruct_counts_on_predict` is False:
- (x_ng is a notational misnomer) Embedding of the input data into the scVI latent space,
typically referred to as ``z_nk``.
- If :attr:`self.reconstruct_counts_on_predict` is True:
- (x_ng is a notational misnomer) Reconstruction of the input data: ``x_ng'``, which
may have a different number of genes depending on :attr:`self.reconstruction_var_names_g`.
"""
if self.reconstruct_counts_on_predict:
assert self.reconstruction_var_names_g is not None
try:
gene_inds = [np.where(var_names_g == gid)[0][0] for gid in self.reconstruction_var_names_g]
except IndexError:
raise ValueError(
f"Some genes to reconstruct ({len(set(self.reconstruction_var_names_g) - set(var_names_g))}) "
f"are missing from the input data: {set(self.reconstruction_var_names_g) - set(var_names_g)}"
)
return self.reconstruct(
x_ng=x_ng,
var_names_g=var_names_g,
gene_inds=gene_inds,
batch_index_n=batch_index_n,
continuous_covariates_nc=continuous_covariates_nc,
categorical_covariate_index_nd=categorical_covariate_index_nd,
transform_batch=self.reconstruction_transform_batch,
transform_categorical_covariates=self.reconstruction_transform_categorical_covariates,
n_latent_samples=self.reconstruction_n_latent_samples,
use_importance_sampling=self.reconstruction_use_importance_sampling,
use_latent_mean=self.reconstruction_use_latent_mean,
reconstructed_library_size=self.reconstructed_library_size,
)
assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)
batch_nb = self.batch_representation_from_batch_index(batch_index_n)
categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd)
z_nk = self._latent_value_from_latent_distribution(
self.inference(
x_ng=x_ng,
batch_nb=batch_nb,
continuous_covariates_nc=continuous_covariates_nc,
categorical_covariate_np=categorical_covariate_np,
)["qz"]
)
return {"x_ng": z_nk}
[docs]
@torch.no_grad()
def reconstruct(
    self,
    x_ng: torch.Tensor,
    var_names_g: np.ndarray,
    gene_inds: np.ndarray | list[int],
    batch_index_n: torch.Tensor,
    continuous_covariates_nc: torch.Tensor | None = None,
    categorical_covariate_index_nd: torch.Tensor | None = None,
    transform_batch: str | int | None = None,
    transform_categorical_covariates: list[int] | None = None,
    use_latent_mean: bool = False,
    n_latent_samples: int = 1000,
    use_importance_sampling: bool = False,
    reconstructed_library_size: float = 10_000,
):
    """
    Reconstruct the data using the VAE, optionally transforming the batch.

    Note: scvi-tools uses the following strategy -
    - for each transform_batch, put the data through the encoder (no dropout)
    - sample n_latent_samples times to get several z values
    - take the mean of the generative distribution
    - obtain tensor shape [n_transform_batches, n_latent_samples, n_cells, n_genes]
    - take a mean over the batches dimension
    (- they optionally use importance sampling based on sampled z likelihoods)
    - take a (weighted?) mean over the n_latent_samples dimension

    Args:
        x_ng: Gene counts matrix.
        var_names_g: The list of the variable names in the input data.
        gene_inds: The indices of the genes from var_names_g to be reconstructed. Output order preserves this order.
        batch_index_n: Batch indices of input cells as integers.
        continuous_covariates_nc: Continuous covariates for each cell (c-dimensional).
        categorical_covariate_index_nd: Categorical covariates for each cell (d-dimensional where d is the number
            of categorical variables). Used for the encoder; also used for the decoder when
            transform_categorical_covariates is None.
        transform_batch: If None, decode each cell under its observed batch. If an int in [0, n_batch), decode
            every cell under that batch. If "mean", decode every cell under each of the n_batch batches and
            average the results.
        transform_categorical_covariates: A list of integer category indices, one per categorical covariate
            (in the same order as n_cats_per_cov), to fix for all cells during decoding. Must be supplied when
            transform_batch is not None and the decoder uses categorical covariates — use
            enumerate_observed_batch_covariate_combinations() to identify valid combinations present in the
            training data. If None, the per-cell observed categorical covariates are passed to the decoder.
        use_latent_mean: If True, use the mean of the latent distribution instead of sampling.
        n_latent_samples: The number of latent samples to use for reconstruction (must be >= 1 when
            use_latent_mean is False).
        use_importance_sampling: True to weight each latent sample of a cell by q(z_s | x) / q(z_mean | x);
            each cell's weighted sum is normalized by that cell's own total weight.
        reconstructed_library_size: The library size to use for the reconstruction, common to all cells.

    Returns:
        A dictionary with the following keys:

        - ``x_ng``: Model's reconstruction of the input data, possibly de-batched. The notational misnomer
          here is that "g" no longer stands for all genes, but the genes in `gene_inds`.

    Raises:
        ValueError: for an invalid transform_batch, a missing transform_categorical_covariates when required,
            or n_latent_samples < 1 when sampling.
    """
    assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
    assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

    if not use_latent_mean and n_latent_samples < 1:
        # previously surfaced as a ZeroDivisionError at the normalization step
        raise ValueError(f"n_latent_samples must be at least 1 when use_latent_mean is False, got {n_latent_samples}")

    if transform_batch is None:
        # a list of size one with the measured values: an actual reconstruction
        transformed_batch_index_n_list = [batch_index_n]
    elif isinstance(transform_batch, str):
        if transform_batch != "mean":
            raise ValueError(
                'If transform_batch is a string, it must be "mean" which '
                "will project counts into each batch and compute the mean, "
                "otherwise specify a particular batch using its integer index"
            )
        # Project into every batch. (Previously silently truncated to the first 10 batches, so "mean" was
        # not a mean over all batches whenever n_batch > 10.)
        transformed_batch_index_n_list = [torch.ones_like(batch_index_n) * i for i in range(self.n_batch)]
    else:
        if transform_batch >= self.n_batch:
            raise ValueError(f"transform_batch must be less than self.n_batch: {self.n_batch}")
        if transform_batch < 0:
            # a negative index would wrap around in the embedding lookup instead of failing
            raise ValueError(f"transform_batch must be non-negative, got {transform_batch}")
        transformed_batch_index_n_list = [torch.ones_like(batch_index_n) * transform_batch]

    # enforce that transform_categorical_covariates is supplied when the decoder uses categorical covariates
    # and transform_batch is set — otherwise the decoder sees per-cell observed covariates, not the target
    # condition, and the reconstruction will still reflect per-cell batch identity.
    if (
        (transform_batch is not None)
        and self._decoder_uses_categorical_covariates
        and (transform_categorical_covariates is None)
    ):
        raise ValueError(
            "This model's decoder uses categorical covariates, but transform_categorical_covariates was not "
            "supplied. When transform_batch is not None, transform_categorical_covariates must also be provided "
            "so that all cells are decoded under the same (real-world) batch+categoricals condition. "
            "Use enumerate_observed_batch_covariate_combinations() to identify valid (batch, covariate) "
            "combinations that were present in the training data, then supply the desired covariate indices "
            "as reconstruction_transform_categorical_covariates in the model config."
        )

    batch_nb = self.batch_representation_from_batch_index(batch_index_n)
    categorical_covariate_np = self.categorical_onehot_from_categorical_index(categorical_covariate_index_nd)
    library_size_n1 = torch.log(x_ng.sum(dim=-1, keepdim=True))

    # generate the latent distribution (always encoded under the observed batch and covariates)
    inference_outputs = self.inference(
        x_ng=x_ng,
        batch_nb=batch_nb,
        continuous_covariates_nc=continuous_covariates_nc,
        categorical_covariate_np=categorical_covariate_np,
    )
    qz = inference_outputs["qz"]
    mean_z_nk = qz.mean

    def _run_generative_and_scale_output(
        z_nk: torch.Tensor,
        local_batch_nb: torch.Tensor,
        local_categorical_covariate_np: torch.Tensor | None,
    ) -> torch.Tensor:
        """Decode z under the given condition; return the distribution mean rescaled to a common library size, subset to gene_inds."""
        generative_outputs = self.generative(
            z_nk=z_nk,
            library_size_n1=library_size_n1,
            batch_nb=local_batch_nb,
            continuous_covariates_nc=continuous_covariates_nc,
            categorical_covariate_np=local_categorical_covariate_np,
        )
        # take the mean of the distribution
        counts_ng = generative_outputs["px"].mean
        # normalize and re-scale
        normalized_counts_ng = counts_ng / counts_ng.sum(dim=-1, keepdim=True)
        scaled_counts_ng = reconstructed_library_size * normalized_counts_ng
        # subset to genes of interest
        return scaled_counts_ng[:, gene_inds]

    n_cells = x_ng.shape[0]
    x_tilde_np: torch.Tensor = torch.zeros(n_cells, len(gene_inds), device=x_ng.device)

    # go through each output batch projection (just one unless transform_batch == "mean")
    for transformed_batch_index_n in transformed_batch_index_n_list:
        local_batch_nb = self.batch_representation_from_batch_index(
            transformed_batch_index_n,
            use_mean_though_sampling=True,  # don't sample during reconstruction
        )

        # build the categorical covariate tensor for this reconstruction condition
        if transform_categorical_covariates is not None:
            fixed_idx_nd = (
                torch.tensor(transform_categorical_covariates, dtype=torch.long, device=x_ng.device)
                .unsqueeze(0)
                .expand(n_cells, -1)
            )
            local_categorical_covariate_np = self.categorical_onehot_from_categorical_index(fixed_idx_nd)
        else:
            local_categorical_covariate_np = categorical_covariate_np

        if use_latent_mean:
            # use the mean in the latent space
            x_tilde_np = x_tilde_np + _run_generative_and_scale_output(
                z_nk=mean_z_nk,
                local_batch_nb=local_batch_nb,
                local_categorical_covariate_np=local_categorical_covariate_np,
            )
            continue

        weighted_counts_sum_np = torch.zeros_like(x_tilde_np)
        weight_sum_n1 = torch.zeros(n_cells, 1, device=x_ng.device)
        for _ in range(n_latent_samples):
            # take a sample from the latent space
            sampled_z_nk = qz.sample()
            if use_importance_sampling:
                # per-cell weight q(z_s | x) / q(z_mean | x), computed in log space
                weight_n1 = (qz.log_prob(sampled_z_nk) - qz.log_prob(mean_z_nk)).sum(dim=-1, keepdim=True).exp()
            else:
                weight_n1 = torch.ones(n_cells, 1, device=x_ng.device)
            scaled_counts_np = _run_generative_and_scale_output(
                z_nk=sampled_z_nk,
                local_batch_nb=local_batch_nb,
                local_categorical_covariate_np=local_categorical_covariate_np,
            )
            weighted_counts_sum_np = weighted_counts_sum_np + scaled_counts_np * weight_n1
            weight_sum_n1 = weight_sum_n1 + weight_n1
        # Self-normalize per cell. Dividing by the minibatch-average weight (as before) made one cell's
        # reconstruction depend on the other cells in the minibatch. Without importance sampling this is
        # exactly the plain mean over n_latent_samples.
        x_tilde_np = x_tilde_np + weighted_counts_sum_np / weight_sum_n1

    x_tilde_np = x_tilde_np / len(transformed_batch_index_n_list)

    return {"x_ng": x_tilde_np}
# ------------------------------------------------------------------
# Validation hooks
# ------------------------------------------------------------------
def on_validation_epoch_start(self, trainer: pl.Trainer) -> None:
    """Reset the per-epoch validation accumulators on the device of the model's parameters.

    Always allocated: ELBO sum, reconstruction-loss sum and cell count (each of shape (1,)).
    When ``self.num_classes`` is set: per-class latent sums and counts, plus a "train" and a "test"
    classifier reservoir of ``val_cell_type_classifier_reservoir_size`` rows (latents and long labels)
    with their fill / seen counters reset to 0.
    When ``self.n_batch > 1``: per-batch latent sums, squared-norm sums and counts.
    """
    device = next(self.parameters()).device

    def zeros(*shape: int, dtype: torch.dtype | None = None) -> torch.Tensor:
        return torch.zeros(*shape, dtype=dtype, device=device)

    # exact-ELBO / reconstruction-loss accumulators
    self._val_elbo_sum = zeros(1)
    self._val_rec_sum = zeros(1)
    self._val_n_cells = zeros(1)

    if self.num_classes is not None:
        # per-class latent centroids for the ontology metric
        self._val_z_sum_kd = zeros(self.num_classes, self.n_latent)
        self._val_class_count_k = zeros(self.num_classes)
        # classifier reservoirs (train = even batch_idx, test = odd batch_idx)
        capacity = self.val_cell_type_classifier_reservoir_size
        self._val_cl_train_z = zeros(capacity, self.n_latent)
        self._val_cl_train_y = zeros(capacity, dtype=torch.long)
        self._val_cl_train_fill = 0
        self._val_cl_train_seen = 0
        self._val_cl_test_z = zeros(capacity, self.n_latent)
        self._val_cl_test_y = zeros(capacity, dtype=torch.long)
        self._val_cl_test_fill = 0
        self._val_cl_test_seen = 0

    if self.n_batch > 1:
        # per-batch statistics for the centroid-based silhouette
        self._val_batch_z_sum_bk = zeros(self.n_batch, self.n_latent)
        self._val_batch_z_sq_sum_b = zeros(self.n_batch)
        self._val_batch_count_b = zeros(self.n_batch)
def validate(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
batch_idx: int,
x_ng: torch.Tensor,
var_names_g: np.ndarray,
batch_index_n: torch.Tensor,
continuous_covariates_nc: torch.Tensor | None = None,
categorical_covariate_index_nd: torch.Tensor | None = None,
total_mrna_umis_n: torch.Tensor | None = None,
validation_cell_type_index_n: torch.Tensor | None = None,
) -> None:
n = x_ng.shape[0]
output = self(
x_ng=x_ng,
var_names_g=var_names_g,
batch_index_n=batch_index_n,
continuous_covariates_nc=continuous_covariates_nc,
categorical_covariate_index_nd=categorical_covariate_index_nd,
total_mrna_umis_n=total_mrna_umis_n,
)
# log annealed loss for progress bar / on-step visibility
if isinstance(output["loss"], torch.Tensor):
pl_module.log("val_loss", output["loss"], sync_dist=True, on_epoch=True, batch_size=n)
# accumulate exact ELBO (no annealing)
kl_batch = output["kl_divergence_batch"]
if not isinstance(kl_batch, torch.Tensor):
kl_batch = torch.zeros(n, device=x_ng.device)
assert isinstance(output["reconstruction_loss"], torch.Tensor)
assert isinstance(output["kl_divergence_z"], torch.Tensor)
assert isinstance(output["z_nk"], torch.Tensor)
elbo_n = -(output["reconstruction_loss"] + output["kl_divergence_z"] + kl_batch)
self._val_elbo_sum += elbo_n.sum().detach()
self._val_rec_sum += output["reconstruction_loss"].sum().detach()
self._val_n_cells += n
z_nk = output["z_nk"].detach()
# accumulate per-class latent sums for ontology metric
if validation_cell_type_index_n is not None and self.num_classes is not None:
idx = validation_cell_type_index_n.long()
self._val_z_sum_kd.index_add_(0, idx, z_nk)
self._val_class_count_k.index_add_(0, idx, torch.ones(n, device=x_ng.device))
# accumulate per-batch latent sums for silhouette metric
if self.n_batch > 1:
bidx = batch_index_n.long()
self._val_batch_z_sum_bk.index_add_(0, bidx, z_nk)
self._val_batch_z_sq_sum_b.index_add_(0, bidx, (z_nk**2).sum(dim=-1))
self._val_batch_count_b.index_add_(0, bidx, torch.ones(n, device=x_ng.device))
# reservoir sampling for cell type classifier
if validation_cell_type_index_n is not None and self.num_classes is not None:
if batch_idx % 2 == 0:
buf_z, buf_y = self._val_cl_train_z, self._val_cl_train_y
fill, seen = self._val_cl_train_fill, self._val_cl_train_seen
else:
buf_z, buf_y = self._val_cl_test_z, self._val_cl_test_y
fill, seen = self._val_cl_test_fill, self._val_cl_test_seen
rs = self.val_cell_type_classifier_reservoir_size
labels = validation_cell_type_index_n.long()
# phase 1: direct fill up to capacity
space = rs - fill
direct = min(space, n)
if direct > 0:
buf_z[fill : fill + direct] = z_nk[:direct]
buf_y[fill : fill + direct] = labels[:direct]
fill += direct
# phase 2: reservoir replacement for overflow cells
for i in range(direct, n):
j = int(torch.randint(0, seen + i - direct + 1, (1,)).item())
if j < rs:
buf_z[j] = z_nk[i]
buf_y[j] = labels[i]
seen += n
if batch_idx % 2 == 0:
self._val_cl_train_fill, self._val_cl_train_seen = fill, seen
else:
self._val_cl_test_fill, self._val_cl_test_seen = fill, seen
@staticmethod
def _weighted_spearman(
latent_dists: torch.Tensor,
onto_dists: torch.Tensor,
counts: torch.Tensor,
) -> float:
import scipy.stats
m = latent_dists.shape[0]
if m < 2:
return float("nan")
idx = torch.triu_indices(m, m, offset=1)
x = latent_dists[idx[0], idx[1]].cpu().float().numpy()
y = onto_dists[idx[0], idx[1]].cpu().float().numpy()
w = (counts[idx[0]] * counts[idx[1]]).cpu().float().numpy()
# exclude disconnected pairs (inf ontology distance — different DAG subtrees)
finite_mask = np.isfinite(y)
if finite_mask.sum() < 3:
return float("nan")
x, y, w = x[finite_mask], y[finite_mask], w[finite_mask]
rx = scipy.stats.rankdata(x)
ry = scipy.stats.rankdata(y)
cov = np.cov(rx, ry, aweights=w)
denom = float(np.sqrt(max(0.0, cov[0, 0] * cov[1, 1])))
if (denom == 0.0) or np.isnan(denom):
logger.warning(
"val_ontology_spearman: zero variance in ontology distances (all pairs at equal distance); logging 0.0"
)
return 0.0
return float(cov[0, 1] / denom)
def on_validation_epoch_end(self, lightning_module: "pl.LightningModule", trainer: "pl.Trainer") -> None:
    """Reduce the validation accumulators and log epoch-level validation metrics.

    Does nothing during the sanity check. Under DDP with more than one rank, the sum accumulators are
    all-reduced and the classifier reservoirs are gathered, concatenated and (if over capacity) randomly
    subsampled on every rank. ``val_reconstruction_loss`` is then logged via ``lightning_module.log`` on
    every rank. The remaining metrics are computed on rank 0 only, and only when ``trainer.logger`` is set:

    - ``val_elbo``: mean exact ELBO per cell.
    - ``val_ontology_spearman``: weighted Spearman correlation between class-centroid latent distances
      and ontology distances.
    - ``val_batch_silhouette``: count-weighted centroid silhouette across batches.
    - ``val_cell_type_top1_accuracy``: logistic regression (LBFGS) fitted on the "train" reservoir and
      evaluated on the "test" reservoir.
    - ``val_cell_type_mean_error_distance``: mean finite ontology distance between predicted and true
      labels of misclassified test cells.
    """
    if trainer.sanity_checking:
        return

    # sync all accumulators across DDP ranks
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
        if world_size > 1:
            tensors_to_sync = [self._val_elbo_sum, self._val_rec_sum, self._val_n_cells]
            if self.num_classes is not None:
                tensors_to_sync += [self._val_z_sum_kd, self._val_class_count_k]
            if self.n_batch > 1:
                tensors_to_sync += [
                    self._val_batch_z_sum_bk,
                    self._val_batch_z_sq_sum_b,
                    self._val_batch_count_b,
                ]
            for t in tensors_to_sync:
                torch.distributed.all_reduce(t, op=torch.distributed.ReduceOp.SUM)

            # gather classifier reservoirs: exchange fill counts then concatenate valid slices
            if self.num_classes is not None:
                for buf_z, buf_y, fill_attr, seen_attr in [
                    (self._val_cl_train_z, self._val_cl_train_y, "_val_cl_train_fill", "_val_cl_train_seen"),
                    (self._val_cl_test_z, self._val_cl_test_y, "_val_cl_test_fill", "_val_cl_test_seen"),
                ]:
                    fill = getattr(self, fill_attr)
                    fill_t = torch.tensor([fill], device=buf_z.device)
                    fills = [torch.zeros_like(fill_t) for _ in range(world_size)]
                    torch.distributed.all_gather(fills, fill_t)
                    gathered_z = [torch.zeros_like(buf_z) for _ in range(world_size)]
                    gathered_y = [torch.zeros_like(buf_y) for _ in range(world_size)]
                    torch.distributed.all_gather(gathered_z, buf_z)
                    torch.distributed.all_gather(gathered_y, buf_y)
                    combined_z = torch.cat([g[: int(f.item())] for g, f in zip(gathered_z, fills)], dim=0)
                    combined_y = torch.cat([g[: int(f.item())] for g, f in zip(gathered_y, fills)], dim=0)
                    # apply reservoir sub-sampling if combined exceeds capacity
                    rs = self.val_cell_type_classifier_reservoir_size
                    if combined_z.shape[0] > rs:
                        perm = torch.randperm(combined_z.shape[0], device=combined_z.device)[:rs]
                        combined_z = combined_z[perm]
                        combined_y = combined_y[perm]
                    buf_z[: combined_z.shape[0]] = combined_z
                    buf_y[: combined_y.shape[0]] = combined_y
                    setattr(self, fill_attr, combined_z.shape[0])

    n_cells = self._val_n_cells.item()
    if n_cells > 0:
        # The sums are identical on every rank at this point (all-reduced above under DDP), so each rank
        # logs the same value with no further collective. Logging with sync_dist=True from rank 0 only,
        # as before, asked Lightning for a cross-rank reduction that the other ranks never joined.
        val_rec = self._val_rec_sum / self._val_n_cells
        lightning_module.log(
            name="val_reconstruction_loss",
            value=val_rec.detach(),
            sync_dist=False,
        )

    # only rank 0 computes and logs the remaining epoch-level metrics
    if trainer.global_rank != 0:
        return
    if trainer.logger is None:
        return
    step = trainer.global_step

    # ELBO
    if n_cells > 0:
        val_elbo = (self._val_elbo_sum / self._val_n_cells).item()
        trainer.logger.log_metrics({"val_elbo": val_elbo}, step=step)

    # ontology-weighted Spearman
    if self.num_classes is not None and self.ontology_distance_matrix is not None:
        valid_mask = self._val_class_count_k > 0
        n_valid = int(valid_mask.sum().item())
        if n_valid >= 2:
            counts_valid = self._val_class_count_k[valid_mask]
            centroids = self._val_z_sum_kd[valid_mask] / counts_valid.unsqueeze(1)
            latent_dists = torch.cdist(centroids, centroids)
            onto_dists = self.ontology_distance_matrix[valid_mask][:, valid_mask]
            spearman = self._weighted_spearman(latent_dists, onto_dists, counts_valid)
            if not np.isnan(spearman):
                trainer.logger.log_metrics({"val_ontology_spearman": spearman}, step=step)

    # centroid-based batch silhouette
    if self.n_batch > 1:
        valid_mask = self._val_batch_count_b > 0
        n_valid = int(valid_mask.sum().item())
        if n_valid >= 2:
            counts_b = self._val_batch_count_b[valid_mask]
            centroids_bk = self._val_batch_z_sum_bk[valid_mask] / counts_b.unsqueeze(1)
            mean_sq_b = self._val_batch_z_sq_sum_b[valid_mask] / counts_b
            # within-batch spread: sqrt(E[||z||^2] - ||centroid||^2)
            intra_b = (mean_sq_b - centroids_bk.pow(2).sum(dim=-1)).clamp(min=0.0).sqrt()
            D = torch.cdist(centroids_bk, centroids_bk)
            # mask self-distances
            D.fill_diagonal_(float("inf"))
            b_b = D.min(dim=1).values  # nearest other centroid distance
            s_b = (b_b - intra_b) / torch.max(b_b, intra_b).clamp(min=1e-12)
            val_silhouette = (s_b * counts_b).sum() / counts_b.sum()
            trainer.logger.log_metrics({"val_batch_silhouette": val_silhouette.item()}, step=step)

    # cell type logistic regression classifier
    if self.num_classes is not None and self._val_cl_train_fill >= 2 and self._val_cl_test_fill >= 1:
        X_train = self._val_cl_train_z[: self._val_cl_train_fill].float()
        y_train = self._val_cl_train_y[: self._val_cl_train_fill]
        X_test = self._val_cl_test_z[: self._val_cl_test_fill].float()
        y_test = self._val_cl_test_y[: self._val_cl_test_fill]

        # normalize features for stable LBFGS convergence
        mu = X_train.mean(0)
        sigma = X_train.std(0).clamp(min=1e-8)
        X_train = (X_train - mu) / sigma
        X_test = (X_test - mu) / sigma

        W = torch.zeros(self.num_classes, self.n_latent, device=X_train.device, requires_grad=True)
        b = torch.zeros(self.num_classes, device=X_train.device, requires_grad=True)
        opt = torch.optim.LBFGS([W, b], max_iter=200, line_search_fn="strong_wolfe")

        def closure():
            opt.zero_grad()
            loss = torch.nn.functional.cross_entropy(X_train @ W.T + b, y_train)
            loss.backward()
            return loss

        # validation usually runs under no_grad; the classifier fit needs autograd
        with torch.enable_grad():
            opt.step(closure)

        with torch.no_grad():
            logits_test = X_test @ W.T + b
            preds = logits_test.argmax(dim=-1)
        top1 = (preds == y_test).float().mean().detach()
        # Computed on rank 0 only, so it must be declared rank-zero-only rather than synced across ranks.
        lightning_module.log(
            name="val_cell_type_top1_accuracy",
            value=top1,
            rank_zero_only=True,
        )

        if self.ontology_distance_matrix is not None:
            wrong_mask = preds != y_test
            if wrong_mask.any():
                err_dist = self.ontology_distance_matrix[preds[wrong_mask], y_test[wrong_mask]]
                finite = err_dist.isfinite()
                if finite.any():
                    mean_err = err_dist[finite].mean().item()
                    trainer.logger.log_metrics({"val_cell_type_mean_error_distance": mean_err}, step=step)
def on_train_batch_end(self, trainer: pl.Trainer) -> None:
    """Copy the trainer's global step and epoch onto the model and log the current KL annealing weight.

    The weight is logged under ``"kl_annealing_weight"`` at ``trainer.global_step``. Logging only happens
    when the trainer has a logger attached.
    """
    self.step = trainer.global_step
    self.epoch = trainer.current_epoch

    pl_logger = trainer.logger
    if pl_logger is None:
        return

    current_kl_weight = compute_annealed_kl_weight(
        epoch=self.epoch,
        step=self.step,
        n_epochs_kl_warmup=self.kl_warmup_epochs,
        n_steps_kl_warmup=self.kl_warmup_steps,
        max_kl_weight=1.0,
        min_kl_weight=self.kl_annealing_start,
    )
    pl_logger.log_metrics({"kl_annealing_weight": current_kl_weight}, step=trainer.global_step)
def batch_index_to_batch_label(adata: AnnData, batch_keys: list[str]) -> pd.DataFrame:
    """
    Convert integer batch index used in the model to a human-readable batch label.

    Every combination of the categories of the ``batch_keys`` columns is enumerated (whether or not
    it occurs in ``adata``), and the integer code the model uses for that combination is attached.

    Args:
        adata: AnnData object. Can be any individual shard as long as categoricals contain all categories.
        batch_keys: List of batch keys. Each referenced ``adata.obs`` column must be a pandas Categorical.

    Returns:
        DataFrame with columns as batch covariates and an extra column "scvi_batch_code" with the
        code used in the model.
    """
    if len(batch_keys) > 1:
        # the warning concerns multi-key lookups only; emitting it for a single key is misleading noise
        logger.warning("The batch_index_to_batch_label lookup for multiple batch_keys is still experimental.")
    df = _enumerate_categorical_combinations(adata.obs[batch_keys])
    df["scvi_batch_code"] = categories_to_product_codes(df)
    return df
def enumerate_observed_batch_covariate_combinations(
    obs: pd.DataFrame,
    batch_keys: list[str],
    categorical_covariate_keys: list[str] | None = None,
) -> pd.DataFrame:
    """
    Tabulate every observed combination of (batch, categorical covariates) in an AnnData obs DataFrame,
    along with the integer indices that the model uses internally for each field.

    This is intended to help users choose valid values for ``reconstruction_transform_batch`` and
    ``reconstruction_transform_categorical_covariates`` when running scVI's reconstruct method.
    The model requires that the specified (batch, covariate) combination was actually present in the
    training data; this function makes it easy to inspect which combinations exist and how many cells
    belong to each.

    Args:
        obs: The ``adata.obs`` DataFrame. All columns listed in ``batch_keys`` and
            ``categorical_covariate_keys`` must be pandas Categoricals.
        batch_keys: The obs column name(s) used as the batch key during training (matching the
            ``batch_index_n`` field in the datamodule config). The integer ``batch_index`` reported
            here is the value passed to the model as ``batch_index_n``, computed as
            ``categories_to_product_codes`` over these columns.
        categorical_covariate_keys: The obs column name(s) used as categorical covariates during
            training (matching the ``categorical_covariate_index_nd`` field in the datamodule config),
            in the same order. For each key, the reported ``{key}_index`` is the 0-based integer code
            within that covariate's categories — i.e., the value to put in position ``i`` of
            ``reconstruction_transform_categorical_covariates``. If None, no covariate columns are
            included.

    Returns:
        A DataFrame with one row per observed (batch, covariate) combination, sorted by ``cell_count``
        descending. The sort is stable, so ties keep the order produced by ``groupby(sort=False)``
        (order of first appearance in ``obs``) and the output is deterministic. Columns are:

        - One label column per batch key (human-readable category label).
        - ``batch_index``: the integer passed to the model as ``batch_index_n``.
        - One label column per categorical covariate key (human-readable category label).
        - One ``{key}_index`` column per categorical covariate key (0-based integer code within
          that covariate's categories, for use as ``reconstruction_transform_categorical_covariates[i]``).
        - ``cell_count``: number of cells with that combination.

    Raises:
        ValueError: If a column named in ``batch_keys`` or ``categorical_covariate_keys`` is not a
            pandas Categorical.
    """
    work = obs[batch_keys].copy()
    for k in batch_keys:
        if not hasattr(work[k], "cat"):
            raise ValueError(f"batch_keys column '{k}' must be a pandas Categorical.")

    # compute batch_index using the same logic as the training pipeline:
    # a single key is passed as a Series, several keys as a DataFrame
    batch_index_series = pd.Series(
        categories_to_product_codes(work[batch_keys] if len(batch_keys) > 1 else work[batch_keys[0]]),
        index=obs.index,
        name="batch_index",
    )
    work["batch_index"] = batch_index_series

    # group by label columns + index columns; cell_count is the group size
    group_cols = [*batch_keys, "batch_index"]
    if categorical_covariate_keys is not None:
        for k in categorical_covariate_keys:
            if not hasattr(obs[k], "cat"):
                raise ValueError(f"categorical_covariate_keys column '{k}' must be a pandas Categorical.")
            idx_col = f"{k}_index"
            work[k] = obs[k]
            # positional assignment: work is a copy of obs columns, so row order matches
            work[idx_col] = obs[k].cat.codes.values
            group_cols += [k, idx_col]

    result = (
        work.groupby(group_cols, observed=True, sort=False)
        .size()
        .reset_index(name="cell_count")
        # stable sort so that tied counts keep groupby's first-appearance order
        .sort_values("cell_count", ascending=False, kind="stable")
        .reset_index(drop=True)
    )
    return result
def _n_cats_per_column(df: pd.DataFrame) -> list[int]:
    """
    Count the categories of every column of a DataFrame.

    Args:
        df: DataFrame whose columns are all pandas Categoricals.

    Returns:
        ``len(df[col].cat.categories)`` for each column, in column order. Unused categories are
        counted too, since the full category set is measured rather than the observed values.
    """
    return [len(df[col].cat.categories) for col in df.columns]
def _enumerate_categorical_combinations(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build the Cartesian product of the categories of every column in a DataFrame.

    Args:
        df: DataFrame whose columns are all pandas Categoricals. Only each column's category set is
            used; the row values themselves are ignored.

    Returns:
        A DataFrame with the same column names as ``df`` and one row per combination of categories,
        including combinations that never occur in ``df``. Rows follow ``itertools.product`` order
        (the last column varies fastest). Each output column is categorical with exactly the
        categories of the corresponding input column, in the same order.
    """
    # enumerate combinations of integer codes first; they are mapped back to labels below
    code_ranges = [range(n) for n in _n_cats_per_column(df)]
    enumerated_df = pd.DataFrame(list(itertools.product(*code_ranges)), columns=df.columns)
    for col in df.columns:
        cats = df[col].cat.categories
        code_to_label = dict(enumerate(cats))
        # astype("category") derives categories from the values; set_categories restores the
        # source column's full category list and order
        enumerated_df[col] = enumerated_df[col].map(code_to_label).astype("category").cat.set_categories(cats)
    return enumerated_df