Models
- class cellarium.ml.models.CellariumGPT(categorical_token_size_dict: dict[str, int], d_model: int, d_ffn: int, n_heads: int, n_blocks: int, dropout_p: float, use_bias: bool, attention_backend: Literal['flex', 'math', 'mem_efficient', 'torch'], attention_softmax_fp32: bool, loss_scale_dict: dict[str, float], initializer_range: float = 0.02, embeddings_scale: float = 1.0, attention_logits_scale: float = 1.0, output_logits_scale: float = 1.0, mup_base_d_model: int | None = None, mup_base_d_ffn: int | None = None)[source]
Bases: CellariumModel, PredictMixin, ValidateMixin
CellariumGPT model.
- Parameters:
categorical_token_size_dict (dict[str, int]) – Categorical token vocabulary sizes. Must include “gene_value” and “gene_id”. Additionally, it can include experimental conditions, such as “assay” and “suspension_type”, and metadata tokens such as “cell_type”, “tissue”, “sex”, “development_stage”, and “disease”.
d_model (int) – Dimensionality of the embeddings and hidden states.
d_ffn (int) – Dimensionality of the inner feed-forward layers.
n_heads (int) – Number of attention heads.
n_blocks (int) – Number of transformer blocks.
dropout_p (float) – Dropout probability.
use_bias (bool) – Whether to use bias in the linear transformations.
attention_backend (Literal['flex', 'math', 'mem_efficient', 'torch']) – Backend for the attention computation.
attention_softmax_fp32 (bool) – Whether to use float32 for the softmax computation when the "torch" backend is used.
loss_scale_dict (dict[str, float]) – A dictionary of loss scales for each label type. These are the query tokens that are used to compute the loss.
initializer_range (float) – The standard deviation of the truncated normal initializer.
embeddings_scale (float) – Multiplier for the embeddings.
attention_logits_scale (float) – Multiplier for the attention logits.
output_logits_scale (float) – Multiplier for the output logits.
mup_base_d_model (int | None) – Base dimensionality of the model for muP.
mup_base_d_ffn (int | None) – Base dimensionality of the inner feed-forward layers for muP.
- predict(token_value_nc_dict: dict[str, Tensor], token_mask_nc_dict: dict[str, Tensor], prompt_mask_nc: Tensor) dict[str, ndarray | Tensor][source]
- Parameters:
token_value_nc_dict (dict[str, Tensor]) – Dictionary of token value tensors of shape (n, c).
token_mask_nc_dict (dict[str, Tensor]) – Dictionary of token mask tensors of shape (n, c).
prompt_mask_nc (Tensor)
- Returns:
Dictionary of logits tensors of shape (n, c, k).
- Return type:
dict[str, ndarray | Tensor]
- class cellarium.ml.models.CellariumModel[source]
Bases: Module
Base class for Cellarium ML compatible models.
- abstractmethod reset_parameters() None[source]
Reset the model parameters and buffers that were constructed in the __init__ method. "Constructed" means created with methods such as torch.tensor, torch.empty, torch.zeros, torch.randn, torch.as_tensor, etc. If a parameter or buffer was assigned instead (e.g., from a torch.Tensor passed to __init__), there is no need to reset it.
- Return type:
None
- class cellarium.ml.models.ContrastiveMLP(n_obs: int, hidden_size: Sequence[int], embed_dim: int, temperature: float = 1.0)[source]
Bases: CellariumModel, PredictMixin
Multilayer perceptron trained with contrastive learning.
- Parameters:
n_obs (int) – Number of observations in each entry (network input size).
hidden_size (Sequence[int]) – Dimensionality of the fully-connected hidden layers.
embed_dim (int) – Size of embedding (network output size).
temperature (float) – Temperature parameter of the normalized temperature-scaled cross-entropy (NT-Xent) loss.
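To illustrate what temperature controls, here is a minimal NumPy sketch of the standard (SimCLR-style) NT-Xent loss, assuming two augmented views per observation whose embeddings form positive pairs. The function nt_xent_loss is hypothetical; how ContrastiveMLP produces views and reduces the loss may differ.

```python
import numpy as np


def nt_xent_loss(z1: np.ndarray, z2: np.ndarray, temperature: float = 1.0) -> float:
    """NT-Xent loss for two views of shape (n, d); row i of z1 and row i of z2 are a positive pair."""
    n = z1.shape[0]
    z = np.concatenate([z1, z2], axis=0)  # (2n, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # L2-normalize so dot products are cosine similarities
    sim = z @ z.T / temperature  # (2n, 2n) scaled similarities
    np.fill_diagonal(sim, -np.inf)  # a sample is never its own negative
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])  # index of each row's positive partner
    # numerically stable log-sum-exp over all non-self entries of each row
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos] - log_denom
    return float(-log_prob_pos.mean())
```

Lowering the temperature sharpens the softmax over similarities, so well-aligned positive pairs yield a smaller loss and hard negatives are penalized more strongly.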
- class cellarium.ml.models.Geneformer(var_names_g: ndarray, hidden_size: int = 256, num_hidden_layers: int = 6, num_attention_heads: int = 4, intermediate_size: int = 512, hidden_act: str = 'relu', hidden_dropout_prob: float = 0.02, attention_probs_dropout_prob: float = 0.02, max_position_embeddings: int = 2048, type_vocab_size: int = 2, initializer_range: float = 0.02, position_embedding_type: str = 'absolute', layer_norm_eps: float = 1e-12, mlm_probability: float = 0.15)[source]
Bases: CellariumModel, PredictMixin
Geneformer model.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
hidden_size (int) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.
hidden_act (str) – The non-linear activation function (function or string) in the encoder and pooler. If a string, "gelu", "relu", "silu", and "gelu_new" are supported.
hidden_dropout_prob (float) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float) – The dropout ratio for the attention probabilities.
max_position_embeddings (int) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (int) – The vocabulary size of the token_type_ids passed when calling transformers.BertModel.
initializer_range (float) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
position_embedding_type (str) – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to Self-Attention with Relative Position Representations (Shaw et al.). For more information on "relative_key_query", please refer to Method 4 in Improve Transformer Models with Better Relative Position Embeddings (Huang et al.).
layer_norm_eps (float) – The epsilon used by the layer normalization layers.
mlm_probability (float) – Ratio of tokens to mask for masked language modeling loss.
- forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor][source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor]
- predict(x_ng: Tensor, var_names_g: ndarray, output_hidden_states: bool = True, output_attentions: bool = True, output_input_ids: bool = True, output_attention_mask: bool = True, feature_activation: list[str] | None = None, feature_deletion: list[str] | None = None, feature_map: dict[str, int] | None = None) dict[str, Tensor | ndarray][source]
Send (transformed) data through the model and return outputs. Optionally perform in silico perturbations and masking.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
output_hidden_states (bool) – Whether to return all hidden-states.
output_attentions (bool) – Whether to return all attentions.
output_input_ids (bool) – Whether to return input ids.
output_attention_mask (bool) – Whether to return attention mask.
feature_activation (list[str] | None) – Features whose expression is set above max(x_ng) before tokenization, so that they take the top rank positions.
feature_deletion (list[str] | None) – Features whose expression is set to zero before tokenization, removing them from the inputs.
feature_map (dict[str, int] | None) – A mapping from input feature names to replacement token ids, applied before the tokens are passed to the model.
- Returns:
A dictionary with the inference results.
- Return type:
dict[str, Tensor | ndarray]
Note
In silico perturbations can be achieved in one of three ways:
(1) Use feature_map to replace a feature token with MASK (1) or PAD (0) (e.g., feature_map={"ENSG0001": 1} replaces the var_names_g feature ENSG0001 with a MASK token).
(2) Use feature_deletion to remove a feature from the cell's inputs. Instead of inserting a PAD or MASK token, this allows another feature to take its place (e.g., feature_deletion=["ENSG0001"] removes the var_names_g feature ENSG0001 from the input and allows a new feature token to take its place).
(3) Use feature_activation to move a feature all the way to the top rank position in the input (e.g., feature_activation=["ENSG0001"] makes the var_names_g feature ENSG0001 the first in rank order). Multiple input features are ranked according to their order in the input list.
Numbers (2) and (3) are described in the Geneformer paper under "In silico perturbation" in the Methods section.
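The three perturbation modes can be illustrated with a simplified, pure-Python sketch of rank-based tokenization. It assumes genes are ranked by raw expression (Geneformer itself normalizes expression by per-gene medians before ranking), uses PAD = 0 and MASK = 1 as stated in the note, and relies on a hypothetical token_ids lookup; it is not the code path used by predict().

```python
from __future__ import annotations

PAD, MASK = 0, 1  # special token ids, as given in the note above


def tokenize_with_perturbations(
    counts: dict[str, float],
    token_ids: dict[str, int],  # hypothetical gene -> token id lookup
    max_len: int,
    feature_activation: list[str] | None = None,
    feature_deletion: list[str] | None = None,
    feature_map: dict[str, int] | None = None,
) -> list[int]:
    counts = dict(counts)  # do not mutate the caller's data
    # (2) deletion: zero expression drops the gene, so lower-ranked genes shift up into the window
    for g in feature_deletion or []:
        counts[g] = 0.0
    # (3) activation: push genes above the current maximum; earlier list entries rank higher
    if feature_activation:
        top = max(counts.values())
        for i, g in enumerate(feature_activation):
            counts[g] = top + len(feature_activation) - i
    # rank expressed genes by decreasing expression (simplified rank value encoding)
    ranked = sorted((g for g, v in counts.items() if v > 0), key=lambda g: -counts[g])
    tokens = [token_ids[g] for g in ranked[:max_len]]
    # (1) map: replace a token in place (e.g. with MASK or PAD), keeping its rank position
    if feature_map:
        replacement = {token_ids[g]: t for g, t in feature_map.items()}
        tokens = [replacement.get(t, t) for t in tokens]
    return tokens + [PAD] * (max_len - len(tokens))  # pad to a fixed length
```

With counts {"A": 5, "B": 3, "C": 1} and a window of two tokens, deleting A lets C enter the input, whereas mapping A to MASK keeps the first position occupied but hides the gene's identity.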
- class cellarium.ml.models.HVGSeuratV3(var_names_g: ndarray, n_top_genes: int | list[int], n_batch: int = 1, flavor: Literal['seurat_v3', 'seurat_v3_paper'] = 'seurat_v3', use_batch_key: bool = False, span: float = 0.3, output_path: str | None = 'hvg_seurat_v3_output.csv', batch_n_cell_minimum: int = 2, n_batch_minimum: int = 1)[source]
Bases: CellariumModel
Compute highly variable genes using the Seurat v3 method in two Lightning epochs.
* Epoch 0: streams data to accumulate per-batch mean and variance.
* Between epochs: fits a LOESS model of log10(var) ~ log10(mean) per batch to estimate a regularized standard deviation and a clip value for each gene in each batch.
* Epoch 1: streams data again, clips each cell's counts at the batch-level clip_val, and accumulates clipped sums.
* After epoch 1 (on_train_epoch_end): computes normalized variance per batch, ranks genes, and writes self.hvg_dfs (and optionally a CSV file at output_path).
The flavor argument controls how genes are ranked across batches when n_batch > 1, matching the two multi-batch modes in Scanpy:
* "seurat_v3" (default): sorts by highly_variable_rank ascending first, then by highly_variable_nbatches descending as a tiebreaker. Genes with the lowest median rank across batches are preferred. This matches Scanpy's flavor='seurat_v3'.
* "seurat_v3_paper": sorts by highly_variable_nbatches descending first, then by highly_variable_rank ascending as a tiebreaker. Genes that are highly variable in more batches are preferred, regardless of their exact rank. This matches Scanpy's flavor='seurat_v3_paper'.
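The difference between the two flavors can be sketched in pure Python, following the Scanpy convention that a gene's rank only counts in batches where it is among the top n_top_genes. The helper select_hvgs is hypothetical and takes precomputed per-batch normalized variances; tie handling inside HVGSeuratV3 may differ.

```python
import statistics


def select_hvgs(norm_var_bg: list[list[float]], n_top_genes: int, flavor: str = "seurat_v3") -> list[int]:
    """Return indices of the selected genes given per-batch normalized variances."""
    n_genes = len(norm_var_bg[0])
    top_ranks = [[] for _ in range(n_genes)]  # ranks of each gene in batches where it is selected
    for var_g in norm_var_bg:
        order = sorted(range(n_genes), key=lambda g: -var_g[g])  # rank 0 = largest normalized variance
        for rank, g in enumerate(order[:n_top_genes]):
            top_ranks[g].append(rank)
    nbatches = [len(r) for r in top_ranks]  # highly_variable_nbatches
    # highly_variable_rank: median rank over batches where the gene is selected (inf if never)
    median_rank = [statistics.median(r) if r else float("inf") for r in top_ranks]
    if flavor == "seurat_v3":
        keys = [(median_rank[g], -nbatches[g]) for g in range(n_genes)]
    elif flavor == "seurat_v3_paper":
        keys = [(-nbatches[g], median_rank[g]) for g in range(n_genes)]
    else:
        raise ValueError(f"unknown flavor: {flavor!r}")
    return sorted(range(n_genes), key=keys.__getitem__)[:n_top_genes]
```

For example, a gene ranked first in two of three batches beats a gene ranked second in all three under "seurat_v3", while "seurat_v3_paper" reverses that order.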
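Usage:
    import lightning.pytorch as pl
    from cellarium.ml.models import HVGSeuratV3

    # gene_names, module (the LightningModule wrapping model), and datamodule are defined elsewhere
    model = HVGSeuratV3(var_names_g=gene_names, n_top_genes=2000, n_batch=4, flavor="seurat_v3_paper", span=0.3)
    trainer = pl.Trainer(max_epochs=2)  # the method runs over two epochs
    trainer.fit(module, datamodule)
    df = model.hvg_df  # pandas DataFrame with Scanpy-compatible columns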
- Parameters:
var_names_g (np.ndarray) – Array of gene names, of length n_genes.
n_top_genes (int | list[int]) – Number of highly variable genes to select. Can be a list of ints to produce multiple gene sets in a single training run.
n_batch (int) – Number of batches (use 1 when no batch information is given).
flavor (Literal['seurat_v3', 'seurat_v3_paper']) – Multi-batch gene ranking strategy. "seurat_v3" (default) prioritizes median rank; "seurat_v3_paper" prioritizes batch consistency.
use_batch_key (bool) – Whether to expect a batch_index_n key in the dataloader output. If False, the model ignores batch keys and treats all data as a single batch.
span (float) – LOESS span (fraction of data used per local fit). Default 0.3.
output_path (str | None) – If given, the result DataFrame is written to this CSV file path (ending in .csv) after training. If None, no file is written; writing the file is strongly recommended, because otherwise the results must be recovered by manually calling _compute_hvg_df() later.
batch_n_cell_minimum (int) – Minimum number of cells required for a batch to be considered valid.
n_batch_minimum (int) – Minimum number of batches in which a gene must be among the n_top_genes highly variable genes to make the final list.
- on_train_epoch_start(trainer: Trainer) None[source]
Cache the current epoch number so forward() can branch on it.
- Parameters:
trainer (Trainer)
- Return type:
None
- on_train_start(trainer: Trainer) None[source]
Validate that, in a distributed setting, DDP is used with broadcast_buffers=False.
- Parameters:
trainer (Trainer)
- Return type:
None
- on_train_epoch_end(trainer: Trainer) None[source]
Implement the two-step Seurat v3 HVG method using hooks at the end of the first and second epochs. After epoch 0, reduce buffers and fit LOESS to set clip_val_bg and reg_std_bg. After epoch 1, reduce buffers and compute hvg_df.
- Parameters:
trainer (Trainer)
- Return type:
None
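The epoch-1 computation can be sketched in NumPy using the formulas from Scanpy's flavor='seurat_v3': the clip value is mean + reg_std * sqrt(N), and the normalized variance is the spread of the clipped counts around the unclipped mean, divided by (N - 1) * reg_std^2. In this sketch reg_std_g is taken as given (in HVGSeuratV3 it comes from the LOESS fit between epochs), only a single batch is shown, and the function name is hypothetical.

```python
import numpy as np


def clipped_normalized_variance(chunks, mean_g: np.ndarray, reg_std_g: np.ndarray, n_cells: int) -> np.ndarray:
    """Stream over cell chunks, clip counts gene-wise, and return the normalized variance per gene."""
    clip_val_g = mean_g + reg_std_g * np.sqrt(n_cells)
    sum_g = np.zeros_like(mean_g, dtype=float)
    sum_sq_g = np.zeros_like(mean_g, dtype=float)
    for x_ng in chunks:
        x_clipped_ng = np.minimum(x_ng, clip_val_g)  # each cell's count for gene g is capped at clip_val_g[g]
        sum_g += x_clipped_ng.sum(axis=0)
        sum_sq_g += (x_clipped_ng**2).sum(axis=0)
    # sum over cells of (clipped - mean)^2, expanded so it can be built from running sums
    ss_g = n_cells * mean_g**2 + sum_sq_g - 2 * sum_g * mean_g
    return ss_g / ((n_cells - 1) * reg_std_g**2)
```

If nothing is clipped and reg_std equals the empirical standard deviation, the result is exactly 1; clipping pulls extreme counts toward the mean and can only lower it.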
- class cellarium.ml.models.IncrementalPCA(var_names_g: ndarray, n_components: int, svd_lowrank_niter: int = 2, perform_mean_correction: bool = True)[source]
Bases: CellariumModel, PredictMixin
Distributed and incremental PCA.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
n_components (int) – Number of principal components.
svd_lowrank_niter (int) – Number of iterations for the low-rank SVD algorithm.
perform_mean_correction (bool) – If True, the mean correction is applied in the update step. If False, the data is assumed to be centered and no mean correction is applied in the update step.
- forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]
Incrementally update partial SVD with new data.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- on_train_epoch_end(trainer: Trainer) None[source]
Merge partial SVD results from parallel processes at the end of the epoch.
Merging SVDs is performed hierarchically. At each merging level, the leading process (even rank) merges its SVD with the trailing process (odd rank). The trailing process discards its SVD and is terminated. The leading process continues to the next level. This process continues until only one process remains. The final SVD is stored on the remaining process.
The number of levels (hierarchy depth) scales logarithmically with the number of processes.
- Parameters:
trainer (Trainer)
- Return type:
None
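The pairing pattern described above is a binary-tree reduction. This pure-Python sketch (the function merge_schedule is hypothetical) only lists which ranks merge at each level; it does not perform the SVD merges or the inter-process communication, and the rank bookkeeping inside IncrementalPCA may differ.

```python
def merge_schedule(world_size: int) -> list[list[tuple[int, int]]]:
    """List the (leader, trailer) rank pairs that merge at each level of the hierarchy."""
    levels = []
    step = 1  # distance between a leader and its trailing partner at this level
    while step < world_size:
        # ranks still active are multiples of `step`; every other active rank leads a merge
        pairs = [
            (leader, leader + step)
            for leader in range(0, world_size, 2 * step)
            if leader + step < world_size  # an unpaired leader carries its SVD to the next level
        ]
        levels.append(pairs)
        step *= 2
    return levels


print(merge_schedule(5))  # [[(0, 1), (2, 3)], [(0, 2)], [(0, 4)]]
```

Each non-zero rank is a trailer exactly once, rank 0 holds the final result, and the number of levels is ceil(log2(world_size)).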
- property explained_variance_k: Tensor
The amount of variance explained by each of the selected components. The variance estimation uses x_size degrees of freedom.
Equal to the n_components largest eigenvalues of the covariance matrix of the input data.
- property components_kg: Tensor
Principal axes in feature space, representing the directions of maximum variance in the data. Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. The components are sorted by decreasing explained_variance_k.
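For intuition, both properties can be reproduced on a small in-memory array with a batch SVD in NumPy. This is a reference computation, not the incremental or distributed algorithm; it assumes that "x_size degrees of freedom" means dividing the squared singular values by the number of cells (rather than n - 1), and the function pca_reference is hypothetical.

```python
import numpy as np


def pca_reference(x_ng: np.ndarray, n_components: int):
    """Exact PCA via the SVD of the centered data."""
    x_centered = x_ng - x_ng.mean(axis=0)
    _, s, vt = np.linalg.svd(x_centered, full_matrices=False)
    components_kg = vt[:n_components]  # right singular vectors, sorted by decreasing singular value
    explained_variance_k = s[:n_components] ** 2 / x_ng.shape[0]  # divide by x_size (number of cells)
    z_nk = x_centered @ components_kg.T  # centering + embedding into PC space
    return components_kg, explained_variance_k, z_nk


rng = np.random.default_rng(0)
x_ng = rng.normal(size=(500, 6)) @ rng.normal(size=(6, 6))  # synthetic correlated data
components_kg, explained_variance_k, z_nk = pca_reference(x_ng, n_components=3)
```

The explained variances equal the top eigenvalues of the covariance matrix computed with n in the denominator, and also equal the per-component variance of the embedding z_nk.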
- predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]
Centering and embedding of the input data x_ng into the principal component space.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the following keys:
x_ng: Embedding of the input data into the principal component space.
var_names_g: The list of variable names for the output data.
- Return type:
dict[str, ndarray | Tensor]
- class cellarium.ml.models.LogisticRegression(n_obs: int, var_names_g: ndarray, y_categories: ndarray, W_prior_scale: float = 1.0, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True)[source]
Bases: CellariumModel, PredictMixin, ValidateMixin
Logistic regression model.
- Parameters:
n_obs (int) – Number of observations.
var_names_g (ndarray) – The variable names schema for the input data validation.
y_categories (ndarray) – The categories for the target data.
W_prior_scale (float) – The scale of the Laplace prior for the weights.
W_init_scale (float) – Initialization scale for the W_gc parameter.
seed (int) – Random seed used to initialize parameters.
log_metrics (bool) – Whether to log the histogram of the W_gc parameter.
- forward(x_ng: Tensor, var_names_g: ndarray, y_n: Tensor, y_categories: ndarray) dict[str, Tensor | None][source]
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
y_n (Tensor) – The target data.
y_categories (ndarray) – The categories for the input target data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor | None]
- predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]
Predict the target logits.
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
- Returns:
A dictionary with the target logits.
- Return type:
dict[str, ndarray | Tensor]
- class cellarium.ml.models.OnePassMeanVarStd(var_names_g: ndarray, algorithm: Literal['naive', 'shifted_data'] = 'naive', n_batch: int = 1, output_path: str | None = 'onepass_mean_var_std_output.csv')[source]
Bases: CellariumModel
Calculate the mean, variance, and standard deviation of the data in one pass (epoch) using running sums and running squared sums.
Tracks per-batch statistics. Use n_batch=1 when there is no meaningful batch structure. After training, batch_mean_bg and batch_var_bg give per-batch, per-gene statistics suitable for passing to get_highly_variable_genes.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
algorithm (Literal['naive', 'shifted_data']) – "naive" (default) or "shifted_data" (numerically stable).
n_batch (int) – Number of batches. Use 1 to reproduce the batch_key=None behavior.
output_path (str | None) – Optional path to save a summary CSV at the end of training, containing the mean and variance for each gene. If None, no CSV will be saved.
- forward(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor | None = None) dict[str, Tensor | None][source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – Variable names in the input data.
batch_index_n (Tensor | None) – Optional batch indices for each cell; required if n_batch > 1.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- property batch_mean_bg: Tensor
Per-batch mean, shape (n_batch, n_genes).
- property batch_var_bg: Tensor
Per-batch population variance, shape (n_batch, n_genes).
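The difference between the two algorithms can be seen in a short NumPy sketch of per-batch running sums. The "shifted_data" variant subtracts a constant shift before accumulating, which avoids catastrophic cancellation when the mean is large relative to the spread. Here the shift is the per-gene mean of the first chunk, which is an assumption (any constant is mathematically valid, and OnePassMeanVarStd may choose it differently); the function name is hypothetical.

```python
import numpy as np


def one_pass_mean_var(chunks, n_batch: int, algorithm: str = "naive"):
    """Per-batch mean and population variance from one pass over (x_ng, batch_index_n) chunks."""
    count_b = np.zeros(n_batch)
    sum_bg = sum_sq_bg = shift_g = None
    for x_ng, batch_index_n in chunks:
        if shift_g is None:
            n_genes = x_ng.shape[1]
            shift_g = x_ng.mean(axis=0) if algorithm == "shifted_data" else np.zeros(n_genes)
            sum_bg = np.zeros((n_batch, n_genes))
            sum_sq_bg = np.zeros((n_batch, n_genes))
        y_ng = x_ng - shift_g  # for "naive" the shift is zero
        np.add.at(count_b, batch_index_n, 1)  # unbuffered: repeated batch indices all accumulate
        np.add.at(sum_bg, batch_index_n, y_ng)
        np.add.at(sum_sq_bg, batch_index_n, y_ng**2)
    mean_shifted_bg = sum_bg / count_b[:, None]
    var_bg = sum_sq_bg / count_b[:, None] - mean_shifted_bg**2  # E[y^2] - E[y]^2 (population variance)
    return mean_shifted_bg + shift_g, var_bg
```

With values such as 1e9 + [0, 1, 2, 3], the naive estimate loses the variance entirely to rounding, while the shifted estimate recovers 1.25 exactly.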
- class cellarium.ml.models.PredictMixin[source]
Bases: object
Abstract mixin class for models that can perform prediction.
- class cellarium.ml.models.ProbabilisticPCA(n_obs: int, var_names_g: ndarray, n_components: int, ppca_flavor: Literal['marginalized', 'linear_vae'], mean_g: Tensor | None = None, W_init_scale: float = 1.0, sigma_init_scale: float = 1.0, seed: int = 0)[source]
Bases: CellariumModel, PredictMixin
Probabilistic PCA implemented in Pyro.
Two flavors of probabilistic PCA are available: marginalized pPCA [1] and linear VAE [2].
References:
[1] Probabilistic Principal Component Analysis (Tipping et al.).
[2] Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).
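For reference, the marginalized flavor corresponds to the standard pPCA model of [1], in which the latent variables are integrated out. It is written below with W in the k × g orientation of the W_kg parameter; exactly how this class wires mean_g and sigma into the model is an assumption. The last line is a general identity for this covariance: with the thin SVD W = U S V^T (V of shape g × k), the covariance has eigenvalues s_k^2 + sigma^2 along the columns of V and sigma^2 elsewhere. Whether L_k and U_gk are computed this way is not stated here.

```latex
\begin{aligned}
z_n &\sim \mathcal{N}(0,\ I_k), \\
x_n \mid z_n &\sim \mathcal{N}\!\left(\mu + W^\top z_n,\ \sigma^2 I_g\right), \qquad W \in \mathbb{R}^{k \times g}, \\
x_n &\sim \mathcal{N}\!\left(\mu,\ W^\top W + \sigma^2 I_g\right), \\
W^\top W + \sigma^2 I_g &= V\,(S^2 + \sigma^2 I_k)\,V^\top + \sigma^2\,(I_g - V V^\top).
\end{aligned}
```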
- Parameters:
n_obs (int) – Number of cells.
var_names_g (ndarray) – The variable names schema for the input data validation.
n_components (int) – Number of principal components.
ppca_flavor (Literal['marginalized', 'linear_vae']) – Type of the PPCA model. Has to be one of marginalized or linear_vae.
mean_g (Tensor | None) – Mean gene expression of the input data. If None, the mean is set to a learnable parameter.
W_init_scale (float) – Scale of the random initialization of the W_kg parameter.
sigma_init_scale (float) – Initialization value of the sigma parameter.
seed (int) – Random seed used to initialize parameters.
- forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor | None]
- predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]
Centering and embedding of the input data x_ng into the principal component space.
Note
Gradients are disabled, used for inference only.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the following keys:
z_nk: Embedding of the input data into the principal component space.
- Return type:
dict[str, ndarray | Tensor]
- property L_k: Tensor
Vector with elements given by the PC eigenvalues.
Note
Gradients are disabled, used for inference only.
- property U_gk: Tensor
Principal components corresponding to eigenvalues L_k.
Note
Gradients are disabled, used for inference only.
- property W_variance: float
Note
Gradients are disabled, used for inference only.
- property sigma_variance: float
Note
Gradients are disabled, used for inference only.
- class cellarium.ml.models.TDigest(var_names_g: ndarray)[source]
Bases: CellariumModel
Compute an approximate non-zero histogram of the distribution of each gene in a batch of cells using t-digests.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
- forward(x_ng: Tensor, var_names_g: ndarray) dict[str, Tensor | None][source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- property median_g: Tensor
Median of the data.
- class cellarium.ml.models.TestMixin[source]
Bases: object
Abstract mixin class for models that can perform testing.
- class cellarium.ml.models.ValidateMixin[source]
Bases: object
Mixin class for models that can perform validation.
- validate(trainer: Trainer, pl_module: LightningModule, batch_idx: int, *args: Any, **kwargs: Any) None[source]
Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
batch_idx (int)
args (Any)
kwargs (Any)
- Return type:
None
- class cellarium.ml.models.SingleCellVariationalInference(var_names_g: Sequence[str], encoder: dict[str, list[dict] | dict | bool], decoder: dict[str, list[dict] | dict | bool], n_batch: int = 0, n_latent: int = 10, n_continuous_cov: int = 0, n_cats_per_cov: list[int] | None = None, dropout_rate: float = 0.1, dispersion: Literal['gene', 'gene-batch', 'gene-label', 'gene-cell'] = 'gene', log_variational: bool = True, gene_likelihood: Literal['zinb', 'nb', 'poisson'] = 'nb', latent_distribution: Literal['normal', 'ln'] = 'normal', batch_embedded: bool = False, batch_representation_sampled: bool = False, n_latent_batch: int | None = None, z_kl_weight_max: float = 1.0, batch_kl_weight_max: float = 0.0, input_gene_dropout_rate: float = 0.0, use_batch_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'both', use_layer_norm: Literal['encoder', 'decoder', 'none', 'both'] = 'none', kl_warmup_epochs: int | None = 400, kl_warmup_steps: int | None = None, kl_annealing_start: float = 0.0, use_size_factor_key: bool = False, reconstruct_counts_on_predict: bool = False, reconstruction_var_names_g: ndarray | list | None = None, reconstruction_transform_batch: None | int | str = 0, reconstruction_n_latent_samples: int = 30, reconstruction_use_latent_mean: bool = False, reconstruction_use_importance_sampling: bool = False, reconstructed_library_size: int = 10000, reconstruction_transform_categorical_covariates: list[int] | None = None, use_flow: bool = False, flow_hidden_features: list[int] = [64, 64], cell_type_categories: list[str] | None = None, ontology_distance_matrix: DataFrame | None = None, val_cell_type_classifier_reservoir_size: int = 50000)[source]
Bases: CellariumModel, PredictMixin, ValidateMixin
Flexible version of single-cell variational inference (scVI) [1] re-implemented in Cellarium ML.
References:
- Parameters:
var_names_g (Sequence[str]) – The variable names schema for the input data validation.
encoder (dict[str, list[dict] | dict | bool]) – Dict specifying the encoder configuration.
decoder (dict[str, list[dict] | dict | bool]) – Dict specifying the decoder configuration.
n_latent (int) – Dimension of the latent space.
n_batch (int) – Number of total batches in the dataset.
batch_representation_sampled (bool) – If True, sample the latent batch representation from a distribution.
n_continuous_cov (int) – Number of continuous covariates.
n_cats_per_cov (list[int] | None) – A list of integers containing the number of categories for each categorical covariate.
dropout_rate (float) – Dropout rate for hidden units in the encoder only.
input_gene_dropout_rate (float) – Gene dropout rate for input data that goes into the encoder during training.
dispersion (Literal['gene', 'gene-batch', 'gene-label', 'gene-cell']) – Flexibility of the dispersion parameter when gene_likelihood is either "nb" or "zinb". One of the following:
* "gene": parameter is constant per gene across cells.
* "gene-batch": parameter is constant per gene per batch.
* "gene-label": parameter is constant per gene per label.
* "gene-cell": parameter is constant per gene per cell.
log_variational (bool) – If True, use log1p() on input data before encoding for numerical stability (not normalization).
gene_likelihood (Literal['zinb', 'nb', 'poisson']) – Distribution to use for reconstruction in the generative process. One of the following:
* "nb": NegativeBinomial.
* "zinb": ZeroInflatedNegativeBinomial (not implemented).
* "poisson": Poisson.
latent_distribution (Literal['normal', 'ln']) – Distribution to use for the latent space. One of the following:
* "normal": isotropic normal.
* "ln": logistic normal with normal params N(0, 1) (not implemented).
use_batch_norm (Literal['encoder', 'decoder', 'none', 'both']) – Specifies where to use BatchNorm1d in the model. One of the following:
* "none": don't use batch norm in either encoder(s) or decoder.
* "encoder": use batch norm only in the encoder(s).
* "decoder": use batch norm only in the decoder.
* "both": use batch norm in both encoder(s) and decoder.
use_layer_norm (Literal['encoder', 'decoder', 'none', 'both']) – Specifies where to use LayerNorm in the model. One of the following:
* "none": don't use layer norm in either encoder(s) or decoder.
* "encoder": use layer norm only in the encoder(s).
* "decoder": use layer norm only in the decoder.
* "both": use layer norm in both encoder(s) and decoder.
Note: only one of use_batch_norm or use_layer_norm should be specified.
use_size_factor_key (bool) – If True, use the obs column defined by the size_factor_key parameter in the model's setup_anndata method as the scaling factor in the mean of the conditional distribution. If False, the observed library size (log of the sum of counts per cell) is used. Should be False.
reconstruct_counts_on_predict (bool) – Changes the behavior of predict(): True reconstructs gene expression count data; False returns the latent representations.
reconstruction_var_names_g (ndarray | list | None) – List of var_names to be reconstructed (outputs are dense matrices).
reconstruction_transform_batch (None | int | str) – None will reconstruct in the original data batch. This is like imputation or smoothing. An integer will reconstruct counts in the specified batch index. The string “mean” will reconstruct counts in the first 10 batches and return the mean.
reconstruction_n_latent_samples (int) – Number of latent samples to use for reconstruction. Each latent sample will be used to compute the mean of the generative distribution, and the final output will be the mean of those.
reconstruction_use_latent_mean (bool) – True to use the mean of the latent distribution rather than sampling.
reconstruction_use_importance_sampling (bool) – True to use importance sampling weighted by each latent sample’s likelihood.
reconstructed_library_size (int) – The library size to use for the reconstruction, common to all cells.
use_flow (bool) – If True, use a Neural Spline Flow (NSF) as the prior on the latent space instead of the standard normal N(0, I). The flow is unconditional (batch-blind) and is jointly trained with the encoder/decoder via an MC-KL estimate: E_q[log q(z|x) - log p_flow(z)].
flow_hidden_features (list[int]) – Hidden layer widths for the NSF. Only used when use_flow=True.
cell_type_categories (list[str] | None) – Ordered list of CL ID strings (e.g., ["CL:0000540", ...]) that matches adata.obs[cell_type_col].cat.categories exactly (same order), so that integer codes from .cat.codes map directly to rows of the internal distance buffer. Required if ontology_distance_matrix is provided.
ontology_distance_matrix (DataFrame | None) – Square pandas.DataFrame with CL ID strings as both index and columns, as returned by compute_cl_distance_matrix(). The constructor slices and reorders it to match cell_type_categories. Enables the frequency-weighted Spearman correlation metric (val_ontology_spearman) during validation. Not saved to checkpoints.
val_cell_type_classifier_reservoir_size (int) – Maximum number of cells to retain per split (train / test) for the logistic regression cell type classifier. Reservoir sampling is used, so this bound is respected regardless of validation set size. Train cells are drawn from even-numbered validation batches; test cells from odd-numbered batches. Ignored when cell_type_categories is not provided. Default 50_000.
batch_embedded (bool)
n_latent_batch (int | None)
z_kl_weight_max (float)
batch_kl_weight_max (float)
kl_warmup_epochs (int | None)
kl_warmup_steps (int | None)
kl_annealing_start (float)
reconstruction_transform_categorical_covariates (list[int] | None)
- batch_representation_from_batch_index(batch_index_n: Tensor, use_mean_though_sampling: bool = False) Tensor[source]
Compute a batch representation from batch indices.
* If self.batch_embedded is False, the batch representation is one-hot (as in scvi-tools).
* If self.batch_embedded is True:
  * If self.batch_representation_sampled is True, the batch representation is sampled from a normal distribution.
  * If self.batch_representation_sampled is False, the batch representation is a point estimate.
- Parameters:
batch_index_n (Tensor)
use_mean_though_sampling (bool)
- Return type:
Tensor
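The three cases above can be sketched in NumPy. The embedding tables loc_bd and log_scale_bd are hypothetical stand-ins for whatever learned parameters SingleCellVariationalInference holds. The use_mean flag reflects an assumed reading of use_mean_though_sampling, namely returning the mean even when sampling is enabled.

```python
import numpy as np


def batch_representation(batch_index_n, n_batch, embedded=False, sampled=False,
                         loc_bd=None, log_scale_bd=None, use_mean=False, rng=None):
    """One-hot, point-estimate, or sampled batch representation for each cell."""
    if not embedded:
        return np.eye(n_batch)[batch_index_n]  # one-hot, shape (n, n_batch)
    loc_nd = loc_bd[batch_index_n]  # hypothetical embedding lookup, shape (n, d)
    if not sampled or use_mean:
        return loc_nd  # point estimate, or the mean of the sampled representation
    rng = np.random.default_rng() if rng is None else rng
    scale_nd = np.exp(log_scale_bd[batch_index_n])
    return loc_nd + scale_nd * rng.standard_normal(loc_nd.shape)  # reparameterized normal sample
```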
- categorical_onehot_from_categorical_index(categorical_covariate_index_nd: Tensor | None) Tensor | None[source]
Compute one-hot encoding of categorical covariates from integer category indices.
- Parameters:
categorical_covariate_index_nd (Tensor | None) – a tensor of shape (n, n_categorical_covariates)
- Return type:
Tensor | None
- inference(x_ng: Tensor, batch_nb: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_np: Tensor | None = None)[source]
High-level inference method. Runs the inference (encoder) model.
- Parameters:
x_ng (Tensor)
batch_nb (Tensor)
continuous_covariates_nc (Tensor | None)
categorical_covariate_np (Tensor | None)
- generative(z_nk: Tensor, library_size_n1: Tensor, batch_nb: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_np: Tensor | None = None) dict[str, Distribution][source]
Runs the generative model.
- Parameters:
z_nk (Tensor)
library_size_n1 (Tensor)
batch_nb (Tensor)
continuous_covariates_nc (Tensor | None)
categorical_covariate_np (Tensor | None)
- Return type:
dict[str, Distribution]
- forward(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None, total_mrna_umis_n: Tensor | None = None)[source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
batch_index_n (Tensor) – Batch indices of input cells as integers.
continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).
categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional). Integer membership categorical codes.
total_mrna_umis_n (Tensor | None) – Total mRNA UMIs for each cell (not log scaled) if this should be used.
- Returns:
A dictionary with keys:
"loss": The total loss value.
"reconstruction_loss": The reconstruction loss value.
"kl_divergence_z": The KL divergence for the latent variable z.
"z_nk": The latent variable z.
- Return type:
dict
- predict(x_ng: Tensor, var_names_g: ndarray, batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None)[source]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
batch_index_n (Tensor) – Batch indices of input cells as integers.
continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).
categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional where d is number of categorical variables). Values are integer membership categorical codes.
- Returns:
A dictionary with the following keys:
x_ng:
- If self.reconstruct_counts_on_predict is False: (x_ng is a notational misnomer) Embedding of the input data into the scVI latent space, typically referred to as z_nk.
- If self.reconstruct_counts_on_predict is True: (x_ng is a notational misnomer) Reconstruction of the input data, x_ng', which may have a different number of genes depending on self.reconstruction_var_names_g.
- reconstruct(x_ng: Tensor, var_names_g: ndarray, gene_inds: ndarray | list[int], batch_index_n: Tensor, continuous_covariates_nc: Tensor | None = None, categorical_covariate_index_nd: Tensor | None = None, transform_batch: str | int | None = None, transform_categorical_covariates: list[int] | None = None, use_latent_mean: bool = False, n_latent_samples: int = 1000, use_importance_sampling: bool = False, reconstructed_library_size: float = 10000)[source]
Reconstruct the data using the VAE, optionally transforming the batch.
- Note: scvi-tools uses the following strategy:
for each transform_batch, pass the data through the encoder (no dropout);
sample n_latent_samples times to get several z values;
take the mean of the generative distribution for each z;
obtain a tensor of shape [n_transform_batches, n_latent_samples, n_cells, n_genes];
take a mean over the batches dimension;
take a mean over the n_latent_samples dimension, optionally weighted by importance sampling based on the likelihoods of the sampled z values.
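The two averaging steps at the end of the note can be sketched with plain PyTorch tensor operations. This is a minimal illustration, not the cellarium or scvi-tools implementation: `average_reconstructions` is a hypothetical helper, it assumes the generative means have already been stacked into a tensor of shape [n_transform_batches, n_latent_samples, n_cells, n_genes], and the self-normalized weighting (a softmax over per-sample log importance weights) is an assumption about how the optional weighting could be done.

```python
from typing import Optional

import torch


def average_reconstructions(
    px_mean_bsng: torch.Tensor,
    log_weights_sn: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Average generative means over transform batches and latent samples.

    px_mean_bsng: [n_transform_batches, n_latent_samples, n_cells, n_genes]
    log_weights_sn: optional unnormalized log importance weights,
        [n_latent_samples, n_cells] (hypothetical input).
    Returns a tensor of shape [n_cells, n_genes].
    """
    # Mean over the transform-batch dimension -> [n_latent_samples, n_cells, n_genes]
    px_mean_sng = px_mean_bsng.mean(dim=0)
    if log_weights_sn is None:
        # Plain Monte Carlo mean over the latent samples
        return px_mean_sng.mean(dim=0)
    # Self-normalized importance weights: per cell, weights sum to 1 over samples
    weights_sn = torch.softmax(log_weights_sn, dim=0)
    # Weighted mean over the latent samples
    return (weights_sn.unsqueeze(-1) * px_mean_sng).sum(dim=0)
```

With equal log weights, the weighted branch reduces to the plain mean.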
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
gene_inds (ndarray | list[int]) – The indices of the genes from var_names_g to be reconstructed. Output order preserves this order.
batch_index_n (Tensor) – Batch indices of input cells as integers.
continuous_covariates_nc (Tensor | None) – Continuous covariates for each cell (c-dimensional).
categorical_covariate_index_nd (Tensor | None) – Categorical covariates for each cell (d-dimensional where d is the number of categorical variables). Used for the encoder; also used for the decoder when transform_categorical_covariates is None.
transform_batch (str | int | None) – If not None, transform the batch to this index before reconstruction.
transform_categorical_covariates (list[int] | None) – A list of integer category indices, one per categorical covariate (in the same order as n_cats_per_cov), to fix for all cells during decoding. Must be supplied when transform_batch is not None and the decoder uses categorical covariates; use enumerate_observed_batch_covariate_combinations() to identify valid combinations present in the training data. If None, the per-cell observed categorical covariates are passed to the decoder.
use_latent_mean (bool) – If True, use the mean of the latent distribution instead of sampling.
n_latent_samples (int) – The number of latent samples to use for reconstruction.
use_importance_sampling (bool) – If True, use importance sampling for the reconstruction, weighting each latent sample by its likelihood.
reconstructed_library_size (float) – The library size to use for the reconstruction, common to all cells.
- Returns:
A dictionary with the following keys:
x_ng: Model’s reconstruction of the input data, possibly de-batched. The notational misnomer here is that “g” no longer stands for all genes, but only for the genes in gene_inds.
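The role of reconstructed_library_size and gene_inds can be illustrated with a short sketch. It assumes, as in scVI's "scale" output, that the decoder produces per-cell gene proportions that sum to 1; the reconstruction is then those proportions multiplied by a library size shared by all cells, restricted to gene_inds in the given order. `rescale_to_library_size` is a hypothetical helper, not part of cellarium.

```python
import torch


def rescale_to_library_size(
    px_scale_ng: torch.Tensor,
    gene_inds,  # list[int] or integer ndarray, as in reconstruct()
    reconstructed_library_size: float = 10000.0,
) -> torch.Tensor:
    # Assumption: each row of px_scale_ng is a normalized proportion vector (sums to 1)
    x_ng = px_scale_ng * reconstructed_library_size
    # Keep only the requested genes, preserving the order given by gene_inds
    idx = torch.as_tensor(gene_inds, dtype=torch.long)
    return x_ng[:, idx]
```

Because the library size is common to all cells, the reconstructed values are directly comparable across cells.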
- class cellarium.ml.models.SOCAM(n_obs: int, var_names_g: ndarray, descendant_tensor: Tensor, cl_names: list[str], cl_name_subset: list[str] | None = None, probability_propagation_flag: bool = True, W_prior_scale: float = 0.01, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True, include_ancestors_of_cl_name_subset: bool = True)[source]
Bases: CellariumModel, PredictMixin, ValidateMixin
Logistic regression model for cell type ontology classification.
- Parameters:
n_obs (int) – Number of observations in the dataset (used to scale the cross-entropy loss).
var_names_g (ndarray) – The variable-name schema for the input data; used for validation.
output_categories – Total number of target categories expected at prediction/validation time. Used when the trained model has fewer categories than the final output space.
descendant_tensor (Tensor) – Binary (0/1) tensor of shape (n_categories, n_categories) defining the descendant relationships between categories. Row i contains ones for all categories considered descendants of category i (plus self). Used for probability propagation.
cl_names (list[str]) – Full list of category identifiers matching the rows/columns of descendant_tensor.
cl_name_subset (list[str] | None) – Optional list of category names (from cl_names) to restrict training and prediction to. The list is sorted internally, so order does not matter. When None, all categories are used.
probability_propagation_flag (bool) – If True, applies hierarchical probability propagation before predicting the output distribution.
W_prior_scale (float) – Scale (b) parameter of the Laplace prior on the weight matrix W_gc.
W_init_scale (float) – Standard deviation for initializing W_gc.
seed (int) – Random seed used to initialize parameters.
log_metrics (bool) – If True, logs weight histograms (TensorBoard) during training.
include_ancestors_of_cl_name_subset (bool)
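Given the layout of descendant_tensor (row i marks category i and its descendants), one natural way to propagate probabilities up the ontology is to give each category the total probability mass of itself and all its descendants, which is a single matrix product. This is a sketch under that assumption; the exact rule SOCAM applies when probability_propagation_flag is True is not documented here, and `propagate_probabilities` is a hypothetical helper.

```python
import torch


def propagate_probabilities(probs_nc: torch.Tensor, descendant_cc: torch.Tensor) -> torch.Tensor:
    """Propagate per-category probabilities to ancestor categories.

    probs_nc: [n_cells, n_categories], rows sum to 1.
    descendant_cc: 0/1 matrix; descendant_cc[i, j] == 1 if j is a descendant of i (or j == i).
    """
    # Propagated probability of category i = sum_j descendant_cc[i, j] * probs_nc[:, j]
    return probs_nc @ descendant_cc.T.to(probs_nc.dtype)
```

For example, with a root category 0 whose children are 1 and 2, the root always receives probability 1, while the leaves keep their own mass.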
- forward(x_ng: Tensor, var_names_g: ndarray, cl_names_n: ndarray) dict[str, Tensor | None][source]
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
cl_names_n (ndarray) – Array of length n containing a category name string (from self.cl_names) for each cell. When self.cl_name_subset is set, every label must be a member of that subset.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor | None]
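The constructor documents two ingredients of the training objective: n_obs scales the cross-entropy loss, and W_prior_scale is the scale b of a Laplace prior on W_gc. A standard way to combine them is a maximum a posteriori (MAP) objective: the minibatch cross-entropy rescaled to the full dataset plus the negative log-density of the prior. The sketch below assumes exactly that combination; it is not SOCAM's actual loss, and `map_loss_sketch` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Laplace


def map_loss_sketch(
    logits_nc: torch.Tensor,  # per-cell logits over the active categories
    targets_n: torch.Tensor,  # integer category index per cell
    W_gc: torch.Tensor,       # weight matrix (genes x categories)
    n_obs: int,
    W_prior_scale: float = 0.01,
) -> torch.Tensor:
    # Mean minibatch cross-entropy rescaled to the full dataset size
    nll = n_obs * F.cross_entropy(logits_nc, targets_n)
    # Zero-mean Laplace prior with scale b = W_prior_scale on every weight
    prior = Laplace(torch.zeros_like(W_gc), torch.full_like(W_gc, W_prior_scale))
    # Negative log-prior summed over all weights (an L1-like penalty plus a constant)
    neg_log_prior = -prior.log_prob(W_gc).sum()
    return nll + neg_log_prior
```

The Laplace prior acts as an L1 penalty on W_gc, which encourages sparse gene weights; a smaller W_prior_scale makes the penalty stronger.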
- validate(trainer: Trainer, pl_module: LightningModule, batch_idx: int, x_ng: Tensor, var_names_g: ndarray, cl_names_n: ndarray) None[source]
Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
batch_idx (int)
x_ng (Tensor)
var_names_g (ndarray)
cl_names_n (ndarray)
- Return type:
None
- predict(x_ng: Tensor, var_names_g: ndarray) dict[str, ndarray | Tensor][source]
Predict the target logits.
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
- Returns:
A dictionary with the target logits. Output tensors have shape (n, n_active_cats).
- Return type:
dict[str, ndarray | Tensor]
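Turning the (n, n_active_cats) logits into one predicted label per cell is an argmax over the category axis. The sketch assumes that column j of the logits corresponds to the j-th name in a list of active category names (for example, the internally sorted cl_name_subset); that column ordering and the helper `logits_to_labels` are assumptions, not documented cellarium behavior.

```python
import numpy as np
import torch


def logits_to_labels(logits_nc: torch.Tensor, active_names: list[str]) -> np.ndarray:
    # argmax over logits equals argmax over softmax probabilities, so no softmax is needed
    best_n = logits_nc.argmax(dim=-1).cpu().numpy()
    # Map each winning column index to its (assumed) category name
    return np.asarray(active_names)[best_n]
```

If calibrated probabilities are needed as well, apply `torch.softmax(logits_nc, dim=-1)` separately.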