Models
- class cellarium.ml.models.CellariumModel
Bases: Module
Base class for Cellarium ML compatible models.
- abstract reset_parameters() -> None
Reset the model parameters and buffers that were constructed in the __init__ method. Constructed means created with methods like torch.tensor, torch.empty, torch.zeros, torch.randn, torch.as_tensor, etc. If the parameter or buffer was assigned (e.g. from a torch.Tensor passed to __init__), then there is no need to reset it.
- Return type:
None
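A minimal sketch of the constructed-versus-assigned distinction, written against plain torch.nn.Module so it is self-contained (a real model would subclass cellarium.ml.models.CellariumModel instead). ToyModel and its attributes are hypothetical: weight and count are constructed in __init__ with torch.empty and torch.zeros, so reset_parameters initializes them; bias is assigned from a tensor passed to __init__, so it is left alone.

```python
import torch
from torch import nn


class ToyModel(nn.Module):  # hypothetical; a real model would subclass CellariumModel
    def __init__(self, n_features: int, init_bias: torch.Tensor) -> None:
        super().__init__()
        # Constructed here with torch.empty / torch.zeros: reset_parameters must initialize them.
        self.weight = nn.Parameter(torch.empty(n_features))
        self.register_buffer("count", torch.zeros(()))
        # Assigned from a tensor passed to __init__: no reset needed.
        self.register_buffer("bias", init_bias)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # torch.empty leaves memory uninitialized, so draw fresh values.
        nn.init.normal_(self.weight, mean=0.0, std=1.0)
        # Running state starts again from zero.
        self.count.zero_()
```

Calling reset_parameters() a second time restores weight and count without touching bias.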
- class cellarium.ml.models.Geneformer(var_names_g: ndarray, hidden_size: int = 256, num_hidden_layers: int = 6, num_attention_heads: int = 4, intermediate_size: int = 512, hidden_act: str = 'relu', hidden_dropout_prob: float = 0.02, attention_probs_dropout_prob: float = 0.02, max_position_embeddings: int = 2048, type_vocab_size: int = 2, initializer_range: float = 0.02, position_embedding_type: str = 'absolute', layer_norm_eps: float = 1e-12, mlm_probability: float = 0.15)
Bases: CellariumModel, PredictMixin
Geneformer model.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
hidden_size (int) – Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (int) – Number of hidden layers in the Transformer encoder.
num_attention_heads (int) – Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (int) – Dimensionality of the “intermediate” (often named feed-forward) layer in the Transformer encoder.
hidden_act (str) – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "silu" and "gelu_new" are supported.
hidden_dropout_prob (float) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (float) – The dropout ratio for the attention probabilities.
max_position_embeddings (int) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (int) – The vocabulary size of the token_type_ids passed when calling transformers.BertModel.
initializer_range (float) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
position_embedding_type (str) – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to Self-Attention with Relative Position Representations (Shaw et al.). For more information on "relative_key_query", please refer to Method 4 in Improve Transformer Models with Better Relative Position Embeddings (Huang et al.).
layer_norm_eps (float) – The epsilon used by the layer normalization layers.
mlm_probability (float) – Ratio of tokens to mask for masked language modeling loss.
- forward(x_ng: Tensor, var_names_g: ndarray) -> dict[str, Tensor]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor]
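The loss returned by forward is a masked-language-modeling loss controlled by mlm_probability. A minimal stdlib sketch of the masking step, under stated assumptions: every selected non-PAD token is simply replaced by MASK (BERT-style variants also keep or randomize a share of the selected tokens), the token IDs PAD = 0 and MASK = 1 match those given in the predict Note, and unselected positions get the label -100, the default ignore_index of PyTorch's cross-entropy loss. mask_tokens is a hypothetical helper, not the library's implementation.

```python
from __future__ import annotations

import random

PAD, MASK = 0, 1  # token IDs as listed in the predict Note
IGNORE = -100     # PyTorch CrossEntropyLoss default ignore_index


def mask_tokens(
    input_ids: list[int], mlm_probability: float = 0.15, seed: int | None = None
) -> tuple[list[int], list[int]]:
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in input_ids:
        if tok != PAD and rng.random() < mlm_probability:
            masked.append(MASK)  # hide the token from the model
            labels.append(tok)   # the loss asks the model to recover it
        else:
            masked.append(tok)
            labels.append(IGNORE)  # position does not contribute to the loss
    return masked, labels
```

Padding is never selected, so the effective masking rate is taken over real gene tokens only.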
- predict(x_ng: Tensor, var_names_g: ndarray, output_hidden_states: bool = True, output_attentions: bool = True, output_input_ids: bool = True, output_attention_mask: bool = True, feature_activation: list[str] | None = None, feature_deletion: list[str] | None = None, feature_map: dict[str, int] | None = None) -> dict[str, Tensor | ndarray]
Send (transformed) data through the model and return outputs. Optionally perform in silico perturbations and masking.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
output_hidden_states (bool) – Whether to return all hidden-states.
output_attentions (bool) – Whether to return all attentions.
output_input_ids (bool) – Whether to return input ids.
output_attention_mask (bool) – Whether to return attention mask.
feature_activation (list[str] | None) – Features whose expression should be set above max(x_ng) before tokenization, moving them to the top rank.
feature_deletion (list[str] | None) – Features whose expression should be set to zero before tokenization, removing them from the inputs.
feature_map (dict[str, int] | None) – A mapping from feature names to replacement token IDs, applied to the input tokens before they are passed to the model.
- Returns:
A dictionary with the inference results.
- Return type:
dict[str, Tensor | ndarray]
Note
In silico perturbations can be achieved in one of three ways:
1. Use feature_map to replace a feature token with MASK (1) or PAD (0). For example, feature_map={"ENSG0001": 1} will replace the var_names_g feature ENSG0001 with a MASK token.
2. Use feature_deletion to remove a feature from the cell's inputs. Instead of adding a PAD or MASK token, this allows another feature to take its place. For example, feature_deletion=["ENSG0001"] will remove the var_names_g feature ENSG0001 from the input and allow a new feature token to take its place.
3. Use feature_activation to move a feature all the way to the top rank position in the input. For example, feature_activation=["ENSG0001"] will make the var_names_g feature ENSG0001 the first in rank order. Multiple input features will be ranked according to their order in the input list.
Options 2 and 3 are described in the Geneformer paper under “In silico perturbation” in the Methods section.
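The three perturbation options can be sketched on a single cell in plain Python. This assumes a simplified rank-value tokenizer: genes with non-zero expression are ordered by decreasing expression, truncated to max_len, and padded with PAD. The original Geneformer tokenization also normalizes expression before ranking, which is omitted here. tokenize_cell, vocab, and max_len are hypothetical names, not the library's API.

```python
from __future__ import annotations

PAD, MASK = 0, 1  # token IDs as listed in the Note


def tokenize_cell(
    x_g: list[float],
    var_names_g: list[str],
    vocab: dict[str, int],
    max_len: int,
    feature_activation: list[str] | None = None,
    feature_deletion: list[str] | None = None,
    feature_map: dict[str, int] | None = None,
) -> list[int]:
    expr = dict(zip(var_names_g, x_g))
    # Deletion: zero the expression so the feature drops out and others move up.
    for name in feature_deletion or []:
        expr[name] = 0.0
    # Activation: push features above max expression, first listed = top rank.
    if feature_activation:
        top = max(expr.values())
        n = len(feature_activation)
        for i, name in enumerate(feature_activation):
            expr[name] = top + (n - i)
    # Rank-value encoding: expressed genes sorted by decreasing expression.
    ranked = sorted((g for g in expr if expr[g] > 0), key=lambda g: -expr[g])
    tokens = [vocab[g] for g in ranked[:max_len]]
    # Mapping: swap specific feature tokens (e.g. for MASK or PAD) in place.
    for name, token_id in (feature_map or {}).items():
        tokens = [token_id if t == vocab[name] else t for t in tokens]
    return tokens + [PAD] * (max_len - len(tokens))
```

With vocab {"A": 2, "B": 3, "C": 4, "D": 5} and expression [5, 0, 3, 1], the unperturbed cell tokenizes to [2, 4, 5, 0]. Deleting "A" lets "D" move up instead of leaving a gap, while mapping "C" to MASK keeps every rank position and only hides the token.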
- class cellarium.ml.models.IncrementalPCA(var_names_g: ndarray, n_components: int, svd_lowrank_niter: int = 2, perform_mean_correction: bool = True)
Bases: CellariumModel, PredictMixin
Distributed and Incremental PCA.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
n_components (int) – Number of principal components.
svd_lowrank_niter (int) – Number of iterations for the low-rank SVD algorithm.
perform_mean_correction (bool) – If True, the mean correction is applied to the update step. If False, the data is assumed to be centered and the mean correction is not applied to the update step.
- forward(x_ng: Tensor, var_names_g: ndarray) -> dict[str, Tensor | None]
Incrementally update partial SVD with new data.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- on_train_epoch_end(trainer: Trainer) -> None
Merge partial SVD results from parallel processes at the end of the epoch.
Merging SVDs is performed hierarchically. At each merging level, the leading process (even rank) merges its SVD with the trailing process (odd rank). The trailing process discards its SVD and is terminated, and the leading process continues to the next level. Merging continues until only one process remains, which stores the final SVD.
The number of levels (hierarchy depth) scales logarithmically with the number of processes.
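The hierarchical merge can be simulated in a single process. This is a minimal NumPy sketch, not the library's implementation: each list entry stands in for one process's partial SVD, mean correction is omitted (as if perform_mean_correction=False), and an exact numpy.linalg.svd replaces the low-rank solver. Two partial SVDs are merged by stacking their row-space summaries diag(S) Vᵀ and decomposing again, which preserves XᵀX. partial_svd, merge_svd, and hierarchical_merge are hypothetical names.

```python
import numpy as np


def partial_svd(x_ng: np.ndarray, k: int):
    # Keep only what later merges need: singular values and right singular vectors.
    _, s, vt = np.linalg.svd(x_ng, full_matrices=False)
    return s[:k], vt[:k].T


def merge_svd(a, b, k: int):
    s_a, v_a = a
    s_b, v_b = b
    # diag(S) V^T has the same Gram matrix X^T X as the block it summarizes.
    stacked = np.vstack([s_a[:, None] * v_a.T, s_b[:, None] * v_b.T])
    return partial_svd(stacked, k)


def hierarchical_merge(svds: list, k: int):
    svds = list(svds)
    n = len(svds)
    step, levels = 1, 0
    while step < n:
        # Leaders sit at multiples of 2*step; each absorbs the partner step ranks away.
        for lead in range(0, n, 2 * step):
            trail = lead + step
            if trail < n:
                svds[lead] = merge_svd(svds[lead], svds[trail], k)
                svds[trail] = None  # the trailing process drops out
        step *= 2
        levels += 1
    return svds[0], levels
```

With five simulated processes the loop runs three levels, matching ceil(log2(5)). When k is smaller than the data rank, each merge truncates and the result is an approximation.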
- Parameters:
trainer (Trainer)
- Return type:
None
- property explained_variance_k: Tensor
The amount of variance explained by each of the selected components. The variance estimation uses x_size degrees of freedom. Equal to the n_components largest eigenvalues of the covariance matrix of the input data.
- property components_kg: Tensor
Principal axes in feature space, representing the directions of maximum variance in the data. Equivalently, the right singular vectors of the centered input data, which are the eigenvectors of its covariance matrix. The components are sorted by decreasing explained_variance_k.
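Both properties follow from the SVD of the centered data X_c = U S Vᵀ: the covariance XᵀX/N equals V diag(S²/N) Vᵀ, so its eigenvalues are S²/N and its eigenvectors are the rows of Vᵀ. A minimal batch NumPy sketch, assuming x_size is the total number of cells N (divisor N rather than N − 1); pca_summary is a hypothetical helper, not how the incremental model computes them.

```python
import numpy as np


def pca_summary(x_ng: np.ndarray, k: int):
    n = x_ng.shape[0]
    xc_ng = x_ng - x_ng.mean(axis=0)  # center each gene
    _, s, vt = np.linalg.svd(xc_ng, full_matrices=False)
    explained_variance_k = s[:k] ** 2 / n  # eigenvalues of the covariance (divisor N)
    components_kg = vt[:k]  # rows already sorted by decreasing singular value
    return explained_variance_k, components_kg
```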
- predict(x_ng: Tensor, var_names_g: ndarray) -> dict[str, ndarray | Tensor]
Centering and embedding of the input data x_ng into the principal component space.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the following keys:
x_ng: Embedding of the input data into the principal component space.
var_names_g: The list of variable names for the output data.
- Return type:
dict[str, ndarray | Tensor]
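Centering and embedding amount to subtracting a mean and projecting onto the principal axes. A minimal NumPy sketch, assuming the centering uses the per-gene mean accumulated during fitting and the axes are components_kg; embed is a hypothetical helper.

```python
import numpy as np


def embed(x_ng: np.ndarray, mean_g: np.ndarray, components_kg: np.ndarray) -> np.ndarray:
    # Center with the fitted mean, then take coordinates along the k principal axes.
    return (x_ng - mean_g) @ components_kg.T
```

On the training data itself this reproduces the first k columns of U·S from the SVD of the centered matrix.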
- class cellarium.ml.models.LogisticRegression(n_obs: int, var_names_g: ndarray, y_categories: ndarray, W_prior_scale: float = 1.0, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True)
Bases: CellariumModel, PredictMixin, ValidateMixin
Logistic regression model.
- Parameters:
n_obs (int) – Number of observations.
var_names_g (ndarray) – The variable names schema for the input data validation.
y_categories (ndarray) – The categories for the target data.
W_prior_scale (float) – The scale of the Laplace prior for the weights.
W_init_scale (float) – Initialization scale for the W_gc parameter.
seed (int) – Random seed used to initialize parameters.
log_metrics (bool) – Whether to log the histogram of the W_gc parameter.
- forward(x_ng: Tensor, var_names_g: ndarray, y_n: Tensor, y_categories: ndarray) -> dict[str, Tensor | None]
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
y_n (Tensor) – The target data.
y_categories (ndarray) – The categories for the input target data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor | None]
- predict(x_ng: Tensor, var_names_g: ndarray) -> dict[str, ndarray | Tensor]
Predict the target logits.
- Parameters:
x_ng (Tensor) – The input data.
var_names_g (ndarray) – The variable names for the input data.
- Returns:
A dictionary with the target logits.
- Return type:
dict[str, ndarray | Tensor]
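To show how n_obs and W_prior_scale could enter the loss, here is a minimal NumPy sketch of a MAP-style objective. It assumes multinomial logistic regression whose minibatch log-likelihood is rescaled to n_obs observations, plus the negative log-density of an i.i.d. Laplace(0, W_prior_scale) prior on W_gc. These are illustrative assumptions: the library may use a different inference scheme or scaling, and log_softmax, predict_logits, map_loss, and the bias b_c are hypothetical.

```python
import numpy as np


def log_softmax(z_nc: np.ndarray) -> np.ndarray:
    # Subtract the row max for numerical stability.
    z_nc = z_nc - z_nc.max(axis=1, keepdims=True)
    return z_nc - np.log(np.exp(z_nc).sum(axis=1, keepdims=True))


def predict_logits(x_ng: np.ndarray, W_gc: np.ndarray, b_c: np.ndarray) -> np.ndarray:
    # One logit per cell and category.
    return x_ng @ W_gc + b_c


def map_loss(x_ng, y_n, W_gc, b_c, n_obs, W_prior_scale=1.0) -> float:
    # Categorical negative log-likelihood of the observed labels in this minibatch.
    log_p_nc = log_softmax(predict_logits(x_ng, W_gc, b_c))
    nll = -log_p_nc[np.arange(len(y_n)), y_n].sum()
    # Rescale so the minibatch stands in for all n_obs observations.
    nll *= n_obs / len(y_n)
    # Laplace(0, b) negative log-density: |w| / b + log(2b) per weight.
    neg_log_prior = np.abs(W_gc).sum() / W_prior_scale + W_gc.size * np.log(2 * W_prior_scale)
    return nll + neg_log_prior
```

The Laplace prior contributes an L1 penalty of strength 1 / W_prior_scale, so smaller scales push more weights toward zero.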
- class cellarium.ml.models.OnePassMeanVarStd(var_names_g: ndarray, algorithm: Literal['naive', 'shifted_data'] = 'naive')
Bases: CellariumModel
Calculate the mean, variance, and standard deviation of the data in one pass (epoch) using running sums and running squared sums.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
algorithm (Literal['naive', 'shifted_data'])
- forward(x_ng: Tensor, var_names_g: ndarray) -> dict[str, Tensor | None]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- property mean_g: Tensor
Mean of the data.
- property var_g: Tensor
Variance of the data.
- property std_g: Tensor
Standard deviation of the data.
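The running-sums approach can be sketched in NumPy. The option names suggest the two textbook variants: 'naive' accumulates Σx and Σx², while 'shifted_data' accumulates the same sums over x − K for a fixed per-gene shift K, which reduces catastrophic cancellation when the mean is large relative to the spread. Taking K from the first batch's mean and dividing the variance by N (not N − 1) are assumptions of this sketch; RunningMeanVar is a hypothetical class.

```python
import numpy as np


class RunningMeanVar:
    def __init__(self, algorithm: str = "naive") -> None:
        self.algorithm = algorithm
        self.n = 0
        self.shift_g = None

    def update(self, x_ng: np.ndarray) -> None:
        if self.shift_g is None:
            g = x_ng.shape[1]
            # 'shifted_data': shift by the first batch mean (assumption); 'naive': no shift.
            self.shift_g = x_ng.mean(axis=0) if self.algorithm == "shifted_data" else np.zeros(g)
            self.sum_g = np.zeros(g)
            self.sq_sum_g = np.zeros(g)
        d_ng = x_ng - self.shift_g
        self.sum_g += d_ng.sum(axis=0)  # running sum
        self.sq_sum_g += (d_ng**2).sum(axis=0)  # running squared sum
        self.n += x_ng.shape[0]

    @property
    def mean_g(self) -> np.ndarray:
        return self.shift_g + self.sum_g / self.n

    @property
    def var_g(self) -> np.ndarray:
        # Shifting does not change the variance: E[d^2] - E[d]^2.
        m_g = self.sum_g / self.n
        return self.sq_sum_g / self.n - m_g**2

    @property
    def std_g(self) -> np.ndarray:
        return np.sqrt(self.var_g)
```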
- class cellarium.ml.models.PredictMixin
Bases: object
Abstract mixin class for models that can perform prediction.
- class cellarium.ml.models.ProbabilisticPCA(n_obs: int, var_names_g: ndarray, n_components: int, ppca_flavor: Literal['marginalized', 'linear_vae'], mean_g: Tensor | None = None, W_init_scale: float = 1.0, sigma_init_scale: float = 1.0, seed: int = 0)
Bases: CellariumModel, PredictMixin
Probabilistic PCA implemented in Pyro.
Two flavors of probabilistic PCA are available: marginalized pPCA [1] and linear VAE [2].
References:
[1] Probabilistic Principal Component Analysis (Tipping et al.).
[2] Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).
- Parameters:
n_obs (int) – Number of cells.
var_names_g (ndarray) – The variable names schema for the input data validation.
n_components (int) – Number of principal components.
ppca_flavor (Literal['marginalized', 'linear_vae']) – Type of the PPCA model. Has to be one of marginalized or linear_vae.
mean_g (Tensor | None) – Mean gene expression of the input data. If None, the mean is set to a learnable parameter.
W_init_scale (float) – Scale of the random initialization of the W_kg parameter.
sigma_init_scale (float) – Initialization value of the sigma parameter.
seed (int) – Random seed used to initialize parameters.
- forward(x_ng: Tensor, var_names_g: ndarray) -> dict[str, Tensor | None]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the loss value.
- Return type:
dict[str, Tensor | None]
- predict(x_ng: Tensor, var_names_g: ndarray) -> dict[str, ndarray | Tensor]
Centering and embedding of the input data x_ng into the principal component space.
Note
Gradients are disabled, used for inference only.
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the following keys:
z_nk: Embedding of the input data into the principal component space.
- Return type:
dict[str, ndarray | Tensor]
- property L_k: Tensor
Vector with elements given by the PC eigenvalues.
Note
Gradients are disabled, used for inference only.
- property U_gk: Tensor
Principal components corresponding to eigenvalues L_k.
Note
Gradients are disabled, used for inference only.
- property W_variance: float
Note
Gradients are disabled, used for inference only.
- property sigma_variance: float
Note
Gradients are disabled, used for inference only.
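For reference, the marginalized model of Tipping and Bishop [1] treats each cell as x ~ N(mean_g, W Wᵀ + σ²I) and has a closed-form maximum-likelihood solution: W = U_k (L_k − σ²I)^½ (up to rotation) and σ² equal to the average discarded eigenvalue. The NumPy sketch below computes that solution and the marginal log-likelihood. It is a reference point only; the class learns its parameters in Pyro, and ppca_ml and ppca_marginal_loglik are hypothetical helpers.

```python
import numpy as np


def ppca_ml(x_ng: np.ndarray, k: int):
    mean_g = x_ng.mean(axis=0)
    cov_gg = np.cov(x_ng, rowvar=False, bias=True)
    evals, evecs = np.linalg.eigh(cov_gg)
    evals, evecs = evals[::-1], evecs[:, ::-1]  # sort descending
    sigma2 = evals[k:].mean()  # average variance of the discarded directions
    W_gk = evecs[:, :k] * np.sqrt(evals[:k] - sigma2)  # rotation R taken as identity
    return mean_g, W_gk, sigma2


def ppca_marginal_loglik(x_ng, mean_g, W_gk, sigma2) -> float:
    g = x_ng.shape[1]
    C_gg = W_gk @ W_gk.T + sigma2 * np.eye(g)  # marginal covariance
    d_ng = x_ng - mean_g
    _, logdet = np.linalg.slogdet(C_gg)
    maha_n = np.einsum("ng,gh,nh->n", d_ng, np.linalg.inv(C_gg), d_ng)
    return -0.5 * (g * np.log(2 * np.pi) + logdet + maha_n).sum()
```

At this solution the top k eigenvalues of W Wᵀ + σ²I match the top k covariance eigenvalues, which is the sense in which L_k and U_gk can be read off a fitted pPCA.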
- class cellarium.ml.models.TDigest(var_names_g: ndarray)
Bases: CellariumModel
Compute an approximate non-zero histogram of the distribution of each gene in a batch of cells using t-digests.
References:
- Parameters:
var_names_g (ndarray) – The variable names schema for the input data validation.
- forward(x_ng: Tensor, var_names_g: ndarray) -> dict[str, Tensor | None]
- Parameters:
x_ng (Tensor) – Gene counts matrix.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
An empty dictionary.
- Return type:
dict[str, Tensor | None]
- property median_g: Tensor
Median of the data.
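The following is not a t-digest. It assumes median_g is taken over the non-zero values of each gene, as the "non-zero histogram" description suggests, and computes that quantity exactly. Because it keeps all values in memory, it is only practical on small data, where it can serve as a reference for checking the streaming estimate. nonzero_median_g is a hypothetical helper, and returning NaN for genes with no non-zero values is a choice of this sketch.

```python
import numpy as np


def nonzero_median_g(x_ng: np.ndarray) -> np.ndarray:
    out_g = np.full(x_ng.shape[1], np.nan)
    for j in range(x_ng.shape[1]):
        col = x_ng[:, j]
        nz = col[col != 0]  # drop zeros before taking the median
        if nz.size:
            out_g[j] = np.median(nz)
    return out_g
```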
- class cellarium.ml.models.ValidateMixin
Bases: object
Mixin class for models that can perform validation.
- validate(trainer: Trainer, pl_module: LightningModule, batch_idx: int, *args: Any, **kwargs: Any) -> None
Default validation method for models. This method logs the validation loss to TensorBoard. Override this method to customize the validation behavior.
- Parameters:
trainer (Trainer)
pl_module (LightningModule)
batch_idx (int)
args (Any)
kwargs (Any)
- Return type:
None