Models

class cellarium.ml.models.CellariumGPT(categorical_token_size_dict: dict[str, int], d_model: int, d_ffn: int, n_heads: int, n_blocks: int, dropout_p: float, use_bias: bool, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, loss_scale_dict: dict[str, float], initializer_range: float = 0.02, embeddings_scale: float = 1.0, attention_logits_scale: float = 1.0, output_logits_scale: float = 1.0, mup_base_d_model: int | None = None, mup_base_d_ffn: int | None = None)[source]

Bases: CellariumModel, PredictMixin, ValidateMixin

CellariumGPT model.

Parameters:
  • categorical_token_size_dict (dict[str, int]) – Categorical token vocabulary sizes. Must include “gene_value” and “gene_id”. Additionally, it can include experimental conditions, such as “assay” and “suspension_type”, and metadata tokens such as “cell_type”, “tissue”, “sex”, “development_stage”, and “disease”.

  • d_model (int) – Dimensionality of the embeddings and hidden states.

  • d_ffn (int) – Dimensionality of the inner feed-forward layers.

  • n_heads (int) – Number of attention heads.

  • n_blocks (int) – Number of transformer blocks.

  • dropout_p (float) – Dropout probability.

  • use_bias (bool) – Whether to use bias in the linear transformations.

  • attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.

  • attention_softmax_fp32 (bool) – Whether to use float32 for softmax computation when torch backend is used.

  • loss_scale_dict (dict[str, float]) – A dictionary of loss scales for each label type. These are the query tokens that are used to compute the loss.

  • initializer_range (float) – The standard deviation of the truncated normal initializer.

  • embeddings_scale (float) – Multiplier for the embeddings.

  • attention_logits_scale (float) – Multiplier for the attention logits.

  • output_logits_scale (float) – Multiplier for the output logits.

  • mup_base_d_model (int | None) – Base dimensionality of the model for muP.

  • mup_base_d_ffn (int | None) – Base dimensionality of the inner feed-forward layers for muP.

predict(token_value_nc_dict: dict[str, Tensor], token_mask_nc_dict: dict[str, Tensor], prompt_mask_nc: Tensor) dict[str, ndarray | Tensor][source]
Parameters:
  • token_value_nc_dict (dict[str, Tensor]) – Dictionary of token value tensors of shape (n, c).

  • token_mask_nc_dict (dict[str, Tensor]) – Dictionary of token mask tensors of shape (n, c).

  • prompt_mask_nc (Tensor)

Returns:

Dictionary of logits tensors of shape (n, c, k).

Return type:

dict[str, ndarray | Tensor]

class cellarium.ml.models.CellariumModel[source]

Bases: Module

Base class for Cellarium ML compatible models.

abstractmethod reset_parameters() None[source]

Reset the model parameters and buffers that were constructed in the __init__ method. Constructed means methods like torch.tensor, torch.empty, torch.zeros, torch.randn, torch.as_tensor, etc. If the parameter or buffer was assigned (e.g. from a torch.Tensor passed to the __init__) then there is no need to reset it.

Return type:

None

class cellarium.ml.models.ContrastiveMLP(n_obs: int, hidden_size: Sequence[int], embed_dim: int, temperature: float = 1.0)[source]

Bases: CellariumModel, PredictMixin

Multilayer perceptron trained with contrastive learning.

Parameters:
  • n_obs (int) – Number of observations in each entry (network input size).

  • hidden_size (Sequence[int]) – Dimensionality of the fully-connected hidden layers.

  • embed_dim (int) – Size of embedding (network output size).

  • temperature (float) – Parameter governing Normalized Temperature-scaled cross-entropy (NT-Xent) loss.
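
As a rough illustration of how temperature enters the NT-Xent loss, the sketch below computes the loss in plain Python for a toy set of embeddings. It assumes the first half and the second half of the rows are two views of the same cells (row i is paired with row i + N); that pairing convention and the helper name nt_xent_loss are assumptions made for this example, not the library's implementation.

```python
import math


def nt_xent_loss(z, temperature=1.0):
    """NT-Xent loss for 2N embeddings where rows i and i + N form a positive pair."""
    two_n = len(z)
    n = two_n // 2
    # L2-normalize each row so that dot products are cosine similarities.
    unit = []
    for row in z:
        norm = math.sqrt(sum(v * v for v in row))
        unit.append([v / norm for v in row])
    total = 0.0
    for i in range(two_n):
        j = (i + n) % two_n  # index of the positive partner of row i
        sims = [
            sum(a * b for a, b in zip(unit[i], unit[k])) / temperature
            for k in range(two_n)
        ]
        # Softmax denominator over all other rows (self-similarity is excluded).
        log_denominator = math.log(
            sum(math.exp(sims[k]) for k in range(two_n) if k != i)
        )
        total += log_denominator - sims[j]
    return total / two_n


# Positive pairs point in the same direction: low loss.
aligned = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]
# Positive pairs are orthogonal: high loss.
shuffled = [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
loss_aligned = nt_xent_loss(aligned, temperature=0.5)
loss_shuffled = nt_xent_loss(shuffled, temperature=0.5)
```

A smaller temperature sharpens the softmax, so mismatched pairs are penalized more strongly.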

forward(x_ng: Tensor) dict[str, Tensor][source]
Parameters:

x_ng (Tensor) – Gene counts matrix.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor]

predict(x_ng: Tensor)[source]

Sends (transformed) data through the model and returns outputs.

Parameters:

x_ng (Tensor) – Gene counts matrix.

Returns:

A dictionary with the embedding matrix.

class cellarium.ml.models.Geneformer(var_names_g: ndarray, hidden_size: int = 256, num_hidden_layers: int = 6, num_attention_heads: int = 4, intermediate_size: int = 512, hidden_act: str = 'relu', hidden_dropout_prob: float = 0.02, attention_probs_dropout_prob: float = 0.02, max_position_embeddings: int = 2048, type_vocab_size: int = 2, initializer_range: float = 0.02, position_embedding_type: str = 'absolute', layer_norm_eps: float = 1e-12, mlm_probability: float = 0.15)[source]

Bases: CellariumModel, PredictMixin

Geneformer model.

References:

  1. Transfer learning enables predictions in network biology (Theodoris et al.).

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • hidden_size (int) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.

  • hidden_act (str) – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "silu" and "gelu_new" are supported.

  • hidden_dropout_prob (float) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float) – The dropout ratio for the attention probabilities.

  • max_position_embeddings (int) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).

  • type_vocab_size (int) – The vocabulary size of the token_type_ids passed when calling transformers.BertModel.

  • initializer_range (float) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • position_embedding_type (str) – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to Self-Attention with Relative Position Representations (Shaw et al.). For more information on "relative_key_query", please refer to Method 4 in Improve Transformer Models with Better Relative Position Embeddings (Huang et al.).

  • layer_norm_eps (float) – The epsilon used by the layer normalization layers.

  • mlm_probability (float) – Ratio of tokens to mask for masked language modeling loss.

forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor][source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor]

predict(x_ng: Tensor, var_names_g: ndarray, output_hidden_states: bool = True, output_attentions: bool = True, output_input_ids: bool = True, output_attention_mask: bool = True, feature_activation: list[str] | None = None, feature_deletion: list[str] | None = None, feature_map: dict[str, int] | None = None) dict[str, Tensor | ndarray][source]

Send (transformed) data through the model and return outputs. Optionally perform in silico perturbations and masking.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

  • output_hidden_states (bool) – Whether to return all hidden-states.

  • output_attentions (bool) – Whether to return all attentions.

  • output_input_ids (bool) – Whether to return input ids.

  • output_attention_mask (bool) – Whether to return attention mask.

  • feature_activation (list[str] | None) – Specify features whose expression should be set to > max(x_ng) before tokenization (top rank).

  • feature_deletion (list[str] | None) – Specify features whose expression should be set to zero before tokenization (remove from inputs).

  • feature_map (dict[str, int] | None) – Specify a mapping for input tokens, to be applied before model.

Returns:

A dictionary with the inference results.

Return type:

dict[str, Tensor | ndarray]

Note

In silico perturbations can be achieved in one of three ways:

  1. Use feature_map to replace a feature token with MASK (1) or PAD (0) (e.g. feature_map={"ENSG0001": 1} will replace var_names_g feature ENSG0001 with a MASK token).

  2. Use feature_deletion to remove a feature from the cell’s inputs, which instead of adding a PAD or MASK token, will allow another feature to take its place (e.g. feature_deletion=["ENSG0001"] will remove var_names_g feature ENSG0001 from the input, and allow a new feature token to take its place).

  3. Use feature_activation to move a feature all the way to the top rank position in the input (e.g. feature_activation=["ENSG0001"] will make var_names_g feature ENSG0001 the first in rank order. Multiple input features will be ranked according to their order in the input list).

Number (2) and (3) are described in the Geneformer paper under “In silico perturbation” in the Methods section.
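
The three perturbation modes can be pictured with a toy rank-based tokenizer. The sketch below is a simplified stand-in, not the library's tokenizer: the function rank_tokens, the gene names, and the token ids (starting at 2 so that 0 and 1 stay reserved for PAD and MASK) are hypothetical, and any normalization applied before ranking is omitted.

```python
def rank_tokens(counts, token_ids, feature_deletion=None,
                feature_activation=None, feature_map=None):
    """Toy rank-value encoding of one cell with optional in silico perturbations."""
    counts = dict(counts)
    # feature_deletion: zero the expression so the gene drops out of the ranking.
    for gene in feature_deletion or []:
        counts[gene] = 0.0
    # feature_activation: push expression above the current maximum; earlier
    # list entries receive larger values and therefore higher ranks.
    top = max(counts.values())
    to_activate = feature_activation or []
    for offset, gene in enumerate(to_activate):
        counts[gene] = top + len(to_activate) - offset
    # Rank expressed genes from highest to lowest expression.
    expressed = [g for g, v in counts.items() if v > 0]
    ranked = sorted(expressed, key=lambda g: counts[g], reverse=True)
    # feature_map: override the token of selected genes (1 = MASK, 0 = PAD).
    ids = dict(token_ids)
    ids.update(feature_map or {})
    return [ids[g] for g in ranked]


cell = {"ENSG0001": 5.0, "ENSG0002": 9.0, "ENSG0003": 0.0, "ENSG0004": 2.0}
vocab = {"ENSG0001": 2, "ENSG0002": 3, "ENSG0003": 4, "ENSG0004": 5}

baseline = rank_tokens(cell, vocab)                                    # [3, 2, 5]
deleted = rank_tokens(cell, vocab, feature_deletion=["ENSG0001"])      # [3, 5]
activated = rank_tokens(cell, vocab, feature_activation=["ENSG0004"])  # [5, 3, 2]
masked = rank_tokens(cell, vocab, feature_map={"ENSG0001": 1})         # [3, 1, 5]
```

Note how deletion shortens the sequence (leaving room for another feature), while the feature_map variant keeps the position and only swaps the token.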

class cellarium.ml.models.HVGSeuratV3(var_names_g: ndarray, n_top_genes: int | list[int], n_batch: int = 1, flavor: Literal['seurat_v3', 'seurat_v3_paper'] = 'seurat_v3', use_batch_key: bool = False, span: float = 0.3, output_path: str | None = 'hvg_seurat_v3_output.csv', batch_n_cell_minimum: int = 2, n_batch_minimum: int = 1)[source]

Bases: CellariumModel

Compute highly variable genes using the Seurat v3 method in two Lightning epochs.

Epoch 0 — streams data to accumulate per-batch mean and variance.

Between epochs — fits a LOESS model of log10(var) ~ log10(mean) per batch to estimate a regularized standard deviation and per-cell clip value.

Epoch 1 — streams data again, clips counts per cell at the batch-level clip_val, and accumulates clipped sums.

After epoch 1 (on_train_epoch_end) — computes normalized variance per batch, ranks genes, and writes self.hvg_dfs (and optionally a CSV file at output_path).
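
The arithmetic behind the second epoch can be sketched for a single gene in a single batch. The example assumes the clip value and normalized-variance formulas used by Scanpy's seurat_v3 flavor; the regularized standard deviation reg_std is passed in directly instead of being fitted with LOESS, and the function name is hypothetical.

```python
import math


def clipped_normalized_variance(counts, reg_std):
    """Normalized variance of one gene in one batch after clipping."""
    n = len(counts)
    mean = sum(counts) / n
    # Clip value (assumed Seurat v3 / Scanpy convention).
    clip_val = reg_std * math.sqrt(n) + mean
    clipped = [min(x, clip_val) for x in counts]
    # Epoch 1 only needs these two running sums per gene and batch.
    clipped_sum = sum(clipped)
    clipped_sq_sum = sum(x * x for x in clipped)
    # Expansion of sum((x - mean)^2) using clipped sums and the unclipped mean.
    numerator = n * mean**2 + clipped_sq_sum - 2.0 * clipped_sum * mean
    return numerator / ((n - 1) * reg_std**2)


# No value exceeds the clip value and reg_std equals the sample std: result is 1.
flat = clipped_normalized_variance([0, 1, 2, 3, 4], reg_std=math.sqrt(2.5))
# A single outlier is clipped, which caps its contribution to the variance.
outlier = clipped_normalized_variance([0, 0, 0, 0, 100], reg_std=1.0)
```

Because only sums and squared sums of clipped counts are required, the second pass can stream the data just like the first.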

The flavor argument controls how genes are ranked across batches when n_batch > 1, matching the two multi-batch modes in Scanpy:

  • "seurat_v3" (default) — sorts by highly_variable_rank ascending first, then by highly_variable_nbatches descending as tiebreaker. Genes with the lowest median rank across batches are preferred. This matches Scanpy’s flavor='seurat_v3'.

  • "seurat_v3_paper" — sorts by highly_variable_nbatches descending first, then by highly_variable_rank ascending as tiebreaker. Genes that are highly variable in more batches are preferred, regardless of their exact rank. This matches Scanpy’s flavor='seurat_v3_paper'.
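
The difference between the two flavors comes down to the order of the sort keys. A minimal sketch with made-up gene records (the helper order_genes and the toy values are illustrative only):

```python
genes = [
    {"gene": "A", "highly_variable_rank": 1.0, "highly_variable_nbatches": 2},
    {"gene": "B", "highly_variable_rank": 3.0, "highly_variable_nbatches": 4},
    {"gene": "C", "highly_variable_rank": 3.0, "highly_variable_nbatches": 3},
    {"gene": "D", "highly_variable_rank": 2.0, "highly_variable_nbatches": 4},
]


def order_genes(rows, flavor):
    """Order genes from most to least preferred under the given flavor."""
    if flavor == "seurat_v3":
        # Rank ascending first, number of batches descending as tiebreaker.
        def key(r):
            return (r["highly_variable_rank"], -r["highly_variable_nbatches"])
    elif flavor == "seurat_v3_paper":
        # Number of batches descending first, rank ascending as tiebreaker.
        def key(r):
            return (-r["highly_variable_nbatches"], r["highly_variable_rank"])
    else:
        raise ValueError(f"unknown flavor: {flavor}")
    return [r["gene"] for r in sorted(rows, key=key)]


by_rank = order_genes(genes, "seurat_v3")           # ['A', 'D', 'B', 'C']
by_batches = order_genes(genes, "seurat_v3_paper")  # ['D', 'B', 'C', 'A']
```

Gene A has the best rank but is variable in only two batches, so it moves from first to last when switching to "seurat_v3_paper".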

Usage:

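import lightning.pytorch as pl  # assumed import for the `pl` alias

model = HVGSeuratV3(var_names_g=gene_names, n_top_genes=2000, n_batch=4,
                    flavor="seurat_v3_paper", span=0.3)
# `module` is assumed to wrap `model`; `datamodule` is assumed to supply the data.
trainer = pl.Trainer(max_epochs=2)  # epoch 0: statistics, epoch 1: clipped sums
trainer.fit(module, datamodule)
df = model.hvg_df  # pandas DataFrame, Scanpy-compatible columns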
Parameters:
  • var_names_g (np.ndarray) – Array of gene names, length n_genes.

  • n_top_genes (int | list[int]) – Number of highly variable genes to select. Can be a list of ints to produce multiple gene sets in a single training run.

  • n_batch (int) – Number of batches (use 1 when no batch information is given).

  • flavor (Literal['seurat_v3', 'seurat_v3_paper']) – Multi-batch gene ranking strategy. "seurat_v3" (default) prioritises median rank; "seurat_v3_paper" prioritises batch consistency.

  • use_batch_key (bool) – Whether to expect a batch_index_n batch key in the dataloader output. If False, the model will ignore batch keys and treat all data as a single batch.

  • span (float) – LOESS span (fraction of data used per local fit). Default 0.3.

  • output_path (str | None) – If given, the result DataFrame is written to this CSV file path (ending in .csv) after training. If None, no file is written. Writing the output is strongly recommended, because otherwise _compute_hvg_df() has to be called manually later.

  • batch_n_cell_minimum (int) – Minimum number of cells required for a batch to be considered valid.

  • n_batch_minimum (int) – Minimum number of batches in which the gene is in n_top_genes highly variable for a gene to make the final list.

on_train_epoch_start(trainer: Trainer) None[source]

Cache the current epoch number so forward() can branch on it.

Parameters:

trainer (Trainer)

Return type:

None

on_train_start(trainer: Trainer) None[source]

Validate the training setup: in a distributed setting, DDP must be used with broadcast_buffers=False.

Parameters:

trainer (Trainer)

Return type:

None

on_train_epoch_end(trainer: Trainer) None[source]

Implement the two-step Seurat v3 HVG method using hooks at the end of first and second epochs. After epoch 0, reduce buffers and fit LOESS to set clip_val_bg and reg_std_bg. After epoch 1, reduce buffers and compute hvg_df.

Parameters:

trainer (Trainer)

Return type:

None

class cellarium.ml.models.IncrementalPCA(var_names_g: ndarray, n_components: int, svd_lowrank_niter: int = 2, perform_mean_correction: bool = True)[source]

Bases: CellariumModel, PredictMixin

Distributed and Incremental PCA.

References:

  1. A Distributed and Incremental SVD Algorithm for Agglomerative Data Analysis on Large Networks (Iwen et al.).

  2. Incremental Learning for Robust Visual Tracking (Ross et al.).

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • n_components (int) – Number of principal components.

  • svd_lowrank_niter (int) – Number of iterations for the low-rank SVD algorithm.

  • perform_mean_correction (bool) – If True then the mean correction is applied to the update step. If False then the data is assumed to be centered and the mean correction is not applied to the update step.

forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]

Incrementally update partial SVD with new data.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

on_train_epoch_end(trainer: Trainer) None[source]

Merge partial SVD results from parallel processes at the end of the epoch.

Merging SVDs is performed hierarchically. At each merging level, the leading process (even rank) merges its SVD with the trailing process (odd rank). The trailing process discards its SVD and is terminated. The leading process continues to the next level. This process continues until only one process remains. The final SVD is stored on the remaining process.

The number of levels (hierarchy depth) scales logarithmically with the number of processes.
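
The pairing pattern and the logarithmic depth can be simulated without any distributed machinery. The sketch below only enumerates which ranks would merge at each level of a pairwise tree reduction; the function merge_schedule is hypothetical, and the exact pairing used by the library may differ.

```python
import math


def merge_schedule(world_size):
    """Return, per level, the (leading, trailing) rank pairs of a tree merge."""
    levels = []
    stride = 1
    while stride < world_size:
        # Surviving ranks are multiples of `stride`; every second survivor
        # leads and absorbs the partial result of its neighbor.
        pairs = [
            (leader, leader + stride)
            for leader in range(0, world_size, 2 * stride)
            if leader + stride < world_size
        ]
        levels.append(pairs)
        stride *= 2
    return levels


schedule = merge_schedule(8)
# Level 0: (0, 1), (2, 3), (4, 5), (6, 7)
# Level 1: (0, 2), (4, 6)
# Level 2: (0, 4)
depth = len(schedule)  # 3 == ceil(log2(8))
```

With a world size that is not a power of two, some ranks simply wait for a later level; for example, five processes still need three levels.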

Parameters:

trainer (Trainer)

Return type:

None

property explained_variance_k: Tensor

The amount of variance explained by each of the selected components. The variance estimation uses x_size degrees of freedom.

Equal to n_components largest eigenvalues of the covariance matrix of input data.

property components_kg: Tensor

Principal axes in feature space, representing the directions of maximum variance in the data. Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. The components are sorted by decreasing explained_variance_k.

predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]

Centering and embedding of the input data x_ng into the principal component space.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the following keys:

  • x_ng: Embedding of the input data into the principal component space.

  • var_names_g: The list of variable names for the output data.

Return type:

dict[str, ndarray | Tensor]

class cellarium.ml.models.LogisticRegression(n_obs: int, var_names_g: ndarray, y_categories: ndarray, W_prior_scale: float = 1.0, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True)[source]

Bases: CellariumModel, PredictMixin, ValidateMixin

Logistic regression model.

Parameters:
  • n_obs (int) – Number of observations.

  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • y_categories (ndarray) – The categories for the target data.

  • W_prior_scale (float) – The scale of the Laplace prior for the weights.

  • W_init_scale (float) – Initialization scale for the W_gc parameter.

  • seed (int) – Random seed used to initialize parameters.

  • log_metrics (bool) – Whether to log the histogram of the W_gc parameter.

forward(x_ng: Tensor, var_names_g: ndarray, y_n: Tensor, y_categories: ndarray) dict[str, Tensor | None][source]
Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

  • y_n (Tensor) – The target data.

  • y_categories (ndarray) – The categories for the input target data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor | None]

predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]

Predict the target logits.

Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

Returns:

A dictionary with the target logits.

Return type:

dict[str, ndarray | Tensor]

class cellarium.ml.models.OnePassMeanVarStd(var_names_g: ndarray, algorithm: Literal['naive', 'shifted_data'] = 'naive', n_batch: int = 1, output_path: str | None = 'onepass_mean_var_std_output.csv')[source]

Bases: CellariumModel

Calculate the mean, variance, and standard deviation of the data in one pass (epoch) using running sums and running squared sums.

Tracks per-batch statistics. Use n_batch=1 when there is no meaningful batch structure. After training, batch_mean_bg and batch_var_bg give per-batch per-gene statistics suitable for passing to get_highly_variable_genes.

References:

  1. Algorithms for calculating variance.

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • algorithm (Literal['naive', 'shifted_data']) – "naive" (default) or "shifted_data" (numerically stable).

  • n_batch (int) – Number of batches. Use 1 to reproduce batch_key=None behavior.

  • output_path (str | None) – Optional path to save a summary CSV at the end of training, containing the mean and variance for each gene. If None, no CSV will be saved.
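
To see why the "shifted_data" option is described as numerically stable, the following sketch compares both running-sum formulas in plain Python on data with a large offset. The function names and the choice of the first value as the shift are assumptions for illustration; how the library picks its shift is not stated here.

```python
def naive_mean_var(values):
    """Population mean and variance from a running sum and squared sum."""
    n = len(values)
    total = sum(values)
    sq_total = sum(v * v for v in values)
    mean = total / n
    return mean, sq_total / n - mean**2


def shifted_mean_var(values, shift):
    """Same running sums, accumulated on values shifted by a constant."""
    n = len(values)
    total = sum(v - shift for v in values)
    sq_total = sum((v - shift) ** 2 for v in values)
    mean_shifted = total / n
    # Variance is invariant to the shift; the mean is shifted back.
    return shift + mean_shifted, sq_total / n - mean_shifted**2


small = [4.0, 7.0, 13.0, 16.0]
large = [1e9 + v for v in small]  # same spread, huge offset

small_mean, small_var = naive_mean_var(small)
naive_mean, naive_var = naive_mean_var(large)  # suffers from cancellation
stable_mean, stable_var = shifted_mean_var(large, shift=large[0])
```

On the small data both approaches agree. On the offset data the naive formula subtracts two numbers of order 1e18, so its variance is unreliable, while the shifted version recovers the same variance as for the small data.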

forward(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor | None = None) dict[str, Tensor | None][source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – Variable names in the input data.

  • batch_index_n (Tensor | None) – Optional batch indices for each cell, required if n_batch > 1.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

property batch_mean_bg: Tensor

Per-batch mean, shape (n_batch, n_genes).

property batch_var_bg: Tensor

Per-batch population variance, shape (n_batch, n_genes).

on_train_end(trainer: Trainer) None[source]

Save a summary output CSV so we need not load the checkpoint downstream.

Parameters:

trainer (Trainer)

Return type:

None

class cellarium.ml.models.PredictMixin[source]

Bases: object

Abstract mixin class for models that can perform prediction.

abstractmethod predict(*args: Any, **kwargs: Any) dict[str, ndarray | Tensor][source]

Perform prediction.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

dict[str, ndarray | Tensor]

class cellarium.ml.models.ProbabilisticPCA(n_obs: int, var_names_g: ndarray, n_components: int, ppca_flavor: Literal['marginalized', 'linear_vae'], mean_g: Tensor | None = None, W_init_scale: float = 1.0, sigma_init_scale: float = 1.0, seed: int = 0)[source]

Bases: CellariumModel, PredictMixin

Probabilistic PCA implemented in Pyro.

Two flavors of probabilistic PCA are available - marginalized pPCA [1] and linear VAE [2].

References:

  1. Probabilistic Principal Component Analysis (Tipping et al.).

  2. Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).
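
For orientation, the textbook probabilistic PCA model of reference [1] can be written as follows, using row vectors and the W_kg, mean_g, and sigma names from this class. That the marginalized flavor corresponds to the second line is an assumption based on that reference rather than on the implementation.

```latex
z_n \sim \mathcal{N}(0,\, I_k), \qquad
x_n \mid z_n \sim \mathcal{N}\!\left(z_n W_{kg} + \mu_g,\; \sigma^2 I_g\right)
\\[4pt]
x_n \sim \mathcal{N}\!\left(\mu_g,\; W_{kg}^{\top} W_{kg} + \sigma^2 I_g\right)
```

The first line keeps the latent variable z_n explicit; the second integrates it out, leaving a Gaussian whose covariance is low rank plus isotropic noise.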

Parameters:
  • n_obs (int) – Number of cells.

  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • n_components (int) – Number of principal components.

  • ppca_flavor (Literal['marginalized', 'linear_vae']) – Type of the PPCA model. Has to be one of marginalized or linear_vae.

  • mean_g (Tensor | None) – Mean gene expression of the input data. If None then the mean is set to a learnable parameter.

  • W_init_scale (float) – Scale of the random initialization of the W_kg parameter.

  • sigma_init_scale (float) – Initialization value of the sigma parameter.

  • seed (int) – Random seed used to initialize parameters.

forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor | None]

predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]

Centering and embedding of the input data x_ng into the principal component space.

Note

Gradients are disabled, used for inference only.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the following keys:

  • z_nk: Embedding of the input data into the principal component space.

Return type:

dict[str, ndarray | Tensor]

property L_k: Tensor

Vector with elements given by the PC eigenvalues.

Note

Gradients are disabled, used for inference only.

property U_gk: Tensor

Principal components corresponding to eigenvalues L_k.

Note

Gradients are disabled, used for inference only.

property W_variance: float

Note

Gradients are disabled, used for inference only.

property sigma_variance: float

Note

Gradients are disabled, used for inference only.

class cellarium.ml.models.TDigest(var_names_g: ndarray)[source]

Bases: CellariumModel

Compute an approximate non-zero histogram of the distribution of each gene in a batch of cells using t-digests.

References:

  1. Computing Extremely Accurate Quantiles Using T-Digests (Dunning et al.).

Parameters:

var_names_g (ndarray) – The variable names schema for the input data validation.

forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

property median_g: Tensor

Median of the data.

class cellarium.ml.models.TestMixin[source]

Bases: object

Abstract mixin class for models that can perform testing.

abstractmethod test(trainer: Trainer, pl_module: LightningModule, batch_idx: int, *args: Any, **kwargs: Any) None[source]

Perform testing.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • batch_idx (int)

  • args (Any)

  • kwargs (Any)

Return type:

None

class cellarium.ml.models.ValidateMixin[source]

Bases: object

Mixin class for models that can perform validation.

validate(trainer: Trainer, pl_module: LightningModule, batch_idx: int, *args: Any, **kwargs: Any) None[source]

Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • batch_idx (int)

  • args (Any)

  • kwargs (Any)

Return type:

None

class cellarium.ml.models.SingleCellVariationalInference(var_names_g: Sequence[str], encoder: dict[str, list[dict] | dict | bool], decoder: dict[str, list[dict] | dict | bool], n_batch: int = 0, n_latent: int = 10, n_continuous_cov: int = 0, n_cats_per_cov: list[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'nb', latent_distribution: Literal['normal', 'ln'] = 'normal', batch_embedded: bool = False, batch_representation_sampled: bool = False, n_latent_batch: int | None = None, z_kl_weight_max: float = 1.0, batch_kl_weight_max: float = 0.0, input_gene_dropout_rate: float = 0.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', kl_warmup_epochs: int | None = 400, kl_warmup_steps: int | None = None, kl_annealing_start: float = 0.0, use_size_factor_key: bool = False, reconstruct_counts_on_predict: bool = False, reconstruction_var_names_g: ndarray | list | None = None, reconstruction_transform_batch: None | int | str = 0, reconstruction_n_latent_samples: int = 30, reconstruction_use_latent_mean: bool = False, reconstruction_use_importance_sampling: bool = False, reconstructed_library_size: int = 10000, reconstruction_transform_categorical_covariates: list[int] | None = None, use_flow: bool = False, flow_hidden_features: list[int] = [64, 64], cell_type_categories: list[str] | None = None, ontology_distance_matrix: DataFrame | None = None, val_cell_type_classifier_reservoir_size: int = 50000)[source]

Bases: CellariumModel, PredictMixin, ValidateMixin

Flexible version of single-cell variational inference (scVI) [1] re-implemented in Cellarium ML.

References:

  1. Deep generative modeling for single-cell transcriptomics (Lopez et al.).

Parameters:
  • var_names_g (Sequence[str]) – The variable names schema for the input data validation.

  • encoder (dict[str, list[dict] | dict | bool]) – Dict specifying the encoder configuration.

  • decoder (dict[str, list[dict] | dict | bool]) – Dict specifying the decoder configuration.

  • n_latent (int) – Dimension of the latent space.

  • n_batch (int) – Number of total batches in the dataset.

  • batch_representation_sampled (bool) – True to sample latent batch from a distribution.

  • n_continuous_cov (int) – Number of continuous covariates.

  • n_cats_per_cov (list[int] | None) – A list of integers containing the number of categories for each categorical covariate.

  • dropout_rate (float) – Dropout rate for hidden units in the encoder only.

  • input_gene_dropout_rate (float) – Gene dropout rate for input data that goes into the encoder during training.

  • dispersion (Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']) –

    Flexibility of the dispersion parameter when gene_likelihood is either "nb" or "zinb". One of the following:

    • "gene": parameter is constant per gene across cells.

    • "gene-batch": parameter is constant per gene per batch.

    • "gene-label": parameter is constant per gene per label.

    • "gene-cell": parameter is constant per gene per cell.

  • log_variational (bool) – If True, use log1p() on input data before encoding for numerical stability (not normalization).

  • gene_likelihood (Literal['zinb', 'nb', 'poisson']) –

    Distribution to use for reconstruction in the generative process. One of the following:

    • "nb": NegativeBinomial.

    • "zinb": ZeroInflatedNegativeBinomial (not implemented).

    • "poisson": Poisson.

  • latent_distribution (Literal['normal', 'ln']) –

    Distribution to use for the latent space. One of the following:

    • "normal": isotropic normal.

    • "ln": logistic normal with normal params N(0, 1) (not implemented).

  • use_batch_norm (Literal['encoder', 'decoder', 'none', 'both']) –

    Specifies where to use BatchNorm1d in the model. One of the following:

    • "none": don’t use batch norm in either encoder(s) or decoder.

    • "encoder": use batch norm only in the encoder(s).

    • "decoder": use batch norm only in the decoder.

    • "both": use batch norm in both encoder(s) and decoder.

  • use_layer_norm (Literal['encoder', 'decoder', 'none', 'both']) –

    Specifies where to use LayerNorm in the model. One of the following:
    • "none": don’t use layer norm in either encoder(s) or decoder.

    • "encoder": use layer norm only in the encoder(s).

    • "decoder": use layer norm only in the decoder.

    • "both": use layer norm in both encoder(s) and decoder.

    Note: only one of use_batch_norm or use_layer_norm should be specified.

  • use_size_factor_key (bool) – If True, use the obs column as defined by the size_factor_key parameter in the model’s setup_anndata method as the scaling factor in the mean of the conditional distribution. If False, the observed library size (log of the sum of counts per cell) is used. Should be False.

  • reconstruct_counts_on_predict (bool) – Changes the behavior of predict(). If True, predict() reconstructs gene expression count data; if False, it returns the latent representations.

  • reconstruction_var_names_g (ndarray | list | None) – List of var_names to be reconstructed (outputs are dense matrices).

  • reconstruction_transform_batch (None | int | str) – None will reconstruct in the original data batch. This is like imputation or smoothing. An integer will reconstruct counts in the specified batch index. The string “mean” will reconstruct counts in the first 10 batches and return the mean.

  • reconstruction_n_latent_samples (int) – Number of latent samples to use for reconstruction. Each latent sample will be used to compute the mean of the generative distribution, and the final output will be the mean of those.

  • reconstruction_use_latent_mean (bool) – True to use the mean of the latent distribution rather than sampling.

  • reconstruction_use_importance_sampling (bool) – True to use importance sampling weighted by each latent sample’s likelihood.

  • reconstructed_library_size (int) – The library size to use for the reconstruction, common to all cells.

  • use_flow (bool) – If True, use a Neural Spline Flow (NSF) as the prior on the latent space instead of the standard normal N(0, I). The flow is unconditional (batch-blind) and is jointly trained with the encoder/decoder via an MC-KL estimate: E_q[log q(z|x) - log p_flow(z)].

  • flow_hidden_features (list[int]) – Hidden layer widths for the NSF. Only used when use_flow=True.

  • cell_type_categories (list[str] | None) – Ordered list of CL ID strings (e.g. ["CL:0000540", ...]) that matches adata.obs[cell_type_col].cat.categories exactly (same order), so that integer codes from .cat.codes map directly to rows of the internal distance buffer. Required if ontology_distance_matrix is provided.

  • ontology_distance_matrix (DataFrame | None) – Square pandas.DataFrame with CL ID strings as both index and columns, as returned by compute_cl_distance_matrix(). The constructor will slice and reorder this to match cell_type_categories. Enables the frequency-weighted Spearman correlation metric (val_ontology_spearman) during validation. Not saved to checkpoints.

  • val_cell_type_classifier_reservoir_size (int) – Maximum number of cells to retain per split (train / test) for the logistic regression cell type classifier. Reservoir sampling is used so this bound is respected regardless of validation set size. Train cells are drawn from even-numbered validation batches; test cells from odd-numbered batches. Ignored when cell_type_categories is not provided. Default 50_000.

  • batch_embedded (bool)

  • n_latent_batch (int | None)

  • z_kl_weight_max (float)

  • batch_kl_weight_max (float)

  • kl_warmup_epochs (int | None)

  • kl_warmup_steps (int | None)

  • kl_annealing_start (float)

  • reconstruction_transform_categorical_covariates (list[int] | None)
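
The reservoir sampling mentioned for val_cell_type_classifier_reservoir_size can be sketched with the classic Algorithm R. This is a generic illustration, not the library's code: the helper reservoir_update, the toy cell labels, and the capacity of 50 are made up, while the even/odd batch split follows the parameter description above.

```python
import random


def reservoir_update(reservoir, item, n_seen, capacity, rng):
    """Algorithm R step; n_seen counts the items offered before this one."""
    if len(reservoir) < capacity:
        reservoir.append(item)
    else:
        # Keep the new item with probability capacity / (n_seen + 1).
        slot = rng.randrange(n_seen + 1)
        if slot < capacity:
            reservoir[slot] = item
    return n_seen + 1


rng = random.Random(0)
capacity = 50
train, test = [], []
n_train = n_test = 0

for batch_idx in range(10):
    cells = [f"batch{batch_idx}_cell{i}" for i in range(20)]
    for cell in cells:
        if batch_idx % 2 == 0:  # even-numbered batches feed the train split
            n_train = reservoir_update(train, cell, n_train, capacity, rng)
        else:                   # odd-numbered batches feed the test split
            n_test = reservoir_update(test, cell, n_test, capacity, rng)
```

Each split ends up with at most capacity cells no matter how many validation batches are streamed, and every offered cell has the same chance of being retained.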

batch_representation_from_batch_index(batch_index_n: Tensor, use_mean_though_sampling: bool = False) Tensor[source]

Compute a batch representation from batch indices.

If self.batch_embedded is False, the batch representation is one-hot (as in scvi-tools). If self.batch_embedded is True:

  • If self.batch_representation_sampled is True, the batch representation is sampled from a normal distribution.

  • If self.batch_representation_sampled is False, the batch representation is a point estimate.

Parameters:
  • batch_index_n (Tensor)

  • use_mean_though_sampling (bool)

Return type:

Tensor
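The branching described above can be sketched in plain Python, with nested lists standing in for tensors. The function batch_representation, the arguments embedding_mean and embedding_std, and the reading of use_mean_though_sampling as "return the mean even when sampling is enabled" are all assumptions made for illustration, not the library's code.

```python
from __future__ import annotations

import random


def batch_representation(
    batch_index_n: list[int],
    n_batch: int,
    batch_embedded: bool,
    batch_representation_sampled: bool = False,
    embedding_mean: list[list[float]] | None = None,  # hypothetical (n_batch, k) table
    embedding_std: list[list[float]] | None = None,  # hypothetical (n_batch, k) table
    use_mean_though_sampling: bool = False,
    rng: random.Random | None = None,
) -> list[list[float]]:
    if not batch_embedded:
        # One-hot encoding of the batch index.
        return [[1.0 if j == b else 0.0 for j in range(n_batch)] for b in batch_index_n]
    if embedding_mean is None:
        raise ValueError("embedding_mean is required when batch_embedded is True")
    if not batch_representation_sampled or use_mean_though_sampling:
        # Point estimate: look up the embedding row for each cell's batch.
        return [list(embedding_mean[b]) for b in batch_index_n]
    if embedding_std is None:
        raise ValueError("embedding_std is required when sampling")
    rng = rng or random.Random(0)
    # Sampled: draw each coordinate from a normal centred on the embedding.
    return [
        [rng.gauss(m, s) for m, s in zip(embedding_mean[b], embedding_std[b])]
        for b in batch_index_n
    ]


means = [[0.1, 0.2], [0.3, 0.4]]
stds = [[1.0, 1.0], [1.0, 1.0]]
one_hot = batch_representation([0, 2], n_batch=3, batch_embedded=False)
point = batch_representation([1], 2, True, False, embedding_mean=means)
mean_anyway = batch_representation(
    [1], 2, True, True, means, stds, use_mean_though_sampling=True
)
sampled = batch_representation([0, 1], 2, True, True, means, stds)
```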

categorical_onehot_from_categorical_index(categorical_covariate_index_nd: Tensor | None) → Tensor | None[source]

Compute one-hot encoding of categorical covariates from integer category indices.

Parameters:

categorical_covariate_index_nd (Tensor | None) – a tensor of shape (n, n_categorical_covariates)

Return type:

Tensor | None
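As a rough illustration of one-hot encoding several categorical covariates at once, the sketch below encodes each column of an (n, n_categorical_covariates) index array with its own number of categories and concatenates the results per cell. The concatenation layout and the function name categorical_onehot are assumptions; n_cats_per_cov is taken to hold the number of categories of each covariate.

```python
from __future__ import annotations


def categorical_onehot(
    categorical_covariate_index_nd: list[list[int]] | None,
    n_cats_per_cov: list[int],
) -> list[list[float]] | None:
    """One-hot encode each covariate column and concatenate the results per cell."""
    if categorical_covariate_index_nd is None:
        return None
    out = []
    for row in categorical_covariate_index_nd:
        encoded: list[float] = []
        for index, n_cats in zip(row, n_cats_per_cov):
            if not 0 <= index < n_cats:
                raise ValueError(f"category index {index} out of range for {n_cats} categories")
            encoded.extend(1.0 if j == index else 0.0 for j in range(n_cats))
        out.append(encoded)
    return out


# Two cells, two covariates with 2 and 3 categories respectively.
onehot = categorical_onehot([[0, 2], [1, 0]], n_cats_per_cov=[2, 3])
```

Under this layout each output row has sum(n_cats_per_cov) columns, here 5.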

inference(x_ng: Tensor, batch_nb: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_np: Tensor | None = None)[source]

High-level inference method. Runs the inference (encoder) model.

Parameters:
  • x_ng (Tensor)

  • batch_nb (Tensor)

  • continuous_covariates_nc (Tensor | None)

  • categorical_covariate_np (Tensor | None)

generative(z_nk: Tensor, library_size_n1: Tensor, batch_nb: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_np: Tensor | None = None) → dict[str, Distribution][source]

Runs the generative model.

Parameters:
  • z_nk (Tensor)

  • library_size_n1 (Tensor)

  • batch_nb (Tensor)

  • continuous_covariates_nc (Tensor | None)

  • categorical_covariate_np (Tensor | None)

Return type:

dict[str, Distribution]

forward(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None, total_mrna_umis_n: Tensor | None = None)[source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

  • batch_index_n (Tensor) – Batch indices of input cells as integers.

  • continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).

  • categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional, where d is the number of categorical variables). Values are integer category codes.

  • total_mrna_umis_n (Tensor | None) – Total mRNA UMIs for each cell (not log scaled) if this should be used.

Returns:

A dictionary with keys:

  • “loss”: The total loss value.

  • “reconstruction_loss”: The reconstruction loss value.

  • “kl_divergence_z”: The KL divergence for the latent variable z.

  • “z_nk”: The latent variable z.

predict(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None)[source]
Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

  • batch_index_n (Tensor) – Batch indices of input cells as integers.

  • continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).

  • categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional where d is number of categorical variables). Values are integer membership categorical codes.

Returns:

A dictionary with the following keys:

  • x_ng:
    • If self.reconstruct_counts_on_predict is False: the embedding of the input data into the scVI latent space, typically referred to as z_nk (here x_ng is a notational misnomer).

    • If self.reconstruct_counts_on_predict is True: the reconstruction of the input data, x_ng', which may have a different number of genes depending on self.reconstruction_var_names_g (here x_ng is a notational misnomer).

reconstruct(x_ng: Tensor, var_names_g: ndarray, gene_inds: ndarray | list[int], batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None, transform_batch: str | int | None = None, transform_categorical_covariates: list[int] | None = None, use_latent_mean: bool = False, n_latent_samples: int = 1000, use_importance_sampling: bool = False, reconstructed_library_size: float = 10000)[source]

Reconstruct the data using the VAE, optionally transforming the batch.

Note: scvi-tools uses the following strategy:
  • for each transform_batch, put the data through the encoder (no dropout)

  • sample n_latent_samples times to get several z values

  • take the mean of the generative distribution

  • obtain a tensor of shape [n_transform_batches, n_latent_samples, n_cells, n_genes]

  • take a mean over the batches dimension

  • (optionally, use importance sampling based on the sampled z likelihoods)

  • take a (weighted?) mean over the n_latent_samples dimension

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

  • gene_inds (ndarray | list[int]) – The indices of the genes from var_names_g to be reconstructed. Output order preserves this order.

  • batch_index_n (Tensor) – Batch indices of input cells as integers.

  • continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).

  • categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional where d is the number of categorical variables). Used for the encoder; also used for the decoder when transform_categorical_covariates is None.

  • transform_batch (str | int | None) – If not None, transform the batch to this index before reconstruction.

  • transform_categorical_covariates (list[int] | None) – A list of integer category indices, one per categorical covariate (in the same order as n_cats_per_cov), to fix for all cells during decoding. Must be supplied when transform_batch is not None and the decoder uses categorical covariates — use enumerate_observed_batch_covariate_combinations() to identify valid combinations present in the training data. If None, the per-cell observed categorical covariates are passed to the decoder.

  • use_latent_mean (bool) – If True, use the mean of the latent distribution instead of sampling.

  • n_latent_samples (int) – The number of latent samples to use for reconstruction.

  • use_importance_sampling (bool) – If True, use importance sampling for the reconstruction, weighting each latent sample by its likelihood.

  • reconstructed_library_size (float) – The library size to use for the reconstruction, common to all cells.

Returns:

A dictionary with the following keys:

  • x_ng: The model’s reconstruction of the input data, possibly de-batched. The notational misnomer here is that “g” no longer stands for all genes, but for the genes in gene_inds.
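The averaging steps in the scvi-tools strategy noted for reconstruct can be sketched for a single cell in plain Python. The function average_reconstruction and its inputs are hypothetical: means_bsg holds generative means with shape [n_transform_batches, n_latent_samples, n_genes], and log_weights_s holds one log importance weight per latent sample. Normalizing the weights with a softmax is an assumption about how a likelihood-weighted mean might be formed, not a statement of what either library does.

```python
from __future__ import annotations

import math


def average_reconstruction(
    means_bsg: list[list[list[float]]],
    log_weights_s: list[float] | None = None,
) -> list[float]:
    n_batches = len(means_bsg)
    n_samples = len(means_bsg[0])
    n_genes = len(means_bsg[0][0])

    # Step 1: plain mean over the transform-batch dimension.
    mean_sg = [
        [sum(means_bsg[b][s][g] for b in range(n_batches)) / n_batches for g in range(n_genes)]
        for s in range(n_samples)
    ]

    # Step 2: per-sample weights, uniform unless importance weights are given.
    if log_weights_s is None:
        weights = [1.0 / n_samples] * n_samples
    else:
        shift = max(log_weights_s)  # subtract the max for numerical stability
        unnormalized = [math.exp(lw - shift) for lw in log_weights_s]
        total = sum(unnormalized)
        weights = [u / total for u in unnormalized]

    # Step 3: (weighted) mean over the latent-sample dimension.
    return [sum(weights[s] * mean_sg[s][g] for s in range(n_samples)) for g in range(n_genes)]


# 2 transform batches, 2 latent samples, 2 genes.
means_bsg = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[3.0, 4.0], [5.0, 6.0]],
]
uniform = average_reconstruction(means_bsg)
weighted = average_reconstruction(means_bsg, log_weights_s=[0.0, math.log(3.0)])
```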

class cellarium.ml.models.SOCAM(n_obs: int, var_names_g: ndarray, descendant_tensor: Tensor, cl_names: list[str], cl_name_subset: list[str] | None = None, probability_propagation_flag: bool = True, W_prior_scale: float = 0.01, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True, include_ancestors_of_cl_name_subset: bool = True)[source]

Bases: CellariumModel, PredictMixin, ValidateMixin

Logistic regression model for cell type ontology classification.

Parameters:
  • n_obs (int) – Number of observations in the dataset (used to scale the cross-entropy loss).

  • var_names_g (ndarray) – The variable-name schema for the input data; used for validation.

  • output_categories – Total number of target categories expected at prediction/validation time. Used when the trained model has fewer categories than the final output space.

  • descendant_tensor (Tensor) – Binary (0/1) tensor of shape (n_categories, n_categories) defining the descendant relationships between categories. Row i contains ones for all categories considered descendants of category i (plus self). Used for probability-propagation.

  • cl_names (list[str]) – Full list of category identifiers matching the rows/columns of descendant_tensor.

  • cl_name_subset (list[str] | None) – Optional list of category names (from cl_names) to restrict training and prediction to. The list is sorted internally so order does not matter. When None, all categories are used.

  • probability_propagation_flag (bool) – If True, applies hierarchical probability propagation before predicting the output distribution.

  • W_prior_scale (float) – Scale (b) parameter of the Laplace prior on the weight matrix W_gc.

  • W_init_scale (float) – Standard deviation for initializing W_gc.

  • seed (int) – Random seed used to initialize parameters.

  • log_metrics (bool) – If True, logs weight histograms (TensorBoard) during training.

  • include_ancestors_of_cl_name_subset (bool)
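To make the role of descendant_tensor concrete, here is a minimal sketch of one plausible form of probability propagation: the propagated probability of category i is the sum of the probabilities of i and all of its descendants, which amounts to a matrix-vector product with the descendant matrix. This reading of "hierarchical probability propagation" is an assumption, and the function propagate and the three-category toy ontology are hypothetical.

```python
def propagate(descendant: list[list[int]], probs_c: list[float]) -> list[float]:
    """Sum each category's probability with those of its descendants (row i marks them)."""
    return [sum(d * p for d, p in zip(row, probs_c)) for row in descendant]


# Toy ontology: category 0 is the root; categories 1 and 2 are its children.
# Row i has ones for category i itself and for every descendant of i.
descendant = [
    [1, 1, 1],
    [0, 1, 0],
    [0, 0, 1],
]
probs_c = [0.1, 0.6, 0.3]  # per-category probabilities before propagation
propagated = propagate(descendant, probs_c)
```

In this toy case the root accumulates the full probability mass, while the leaves keep their own values, so a coarse label can be scored confidently even when the fine-grained label is uncertain.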

forward(x_ng: Tensor, var_names_g: ndarray, cl_names_n: ndarray) → dict[str, Tensor | None][source]
Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

  • cl_names_n (ndarray) – Array of length n containing a category name string (from self.cl_names) for each cell. When self.cl_name_subset is set, every label must be a member of that subset.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor | None]

validate(trainer: Trainer, pl_module: LightningModule, batch_idx: int, x_ng: Tensor, var_names_g: ndarray, cl_names_n: ndarray) → None[source]

Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior.

Parameters:
  • trainer (Trainer)

  • pl_module (LightningModule)

  • batch_idx (int)

  • x_ng (Tensor)

  • var_names_g (ndarray)

  • cl_names_n (ndarray)

Return type:

None

predict(x_ng: Tensor, var_names_g: ndarray) → dict[str, ndarray | Tensor][source]

Predict the target logits.

Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

Returns:

A dictionary with the target logits. Output tensors have shape (n, n_active_cats).

Return type:

dict[str, ndarray | Tensor]