Models

class cellarium.ml.models.CellariumModel[source]

Bases: Module

Base class for Cellarium ML compatible models.

abstract reset_parameters() → None[source]

Reset the model parameters and buffers that were constructed in the __init__ method. Here “constructed” means created with factory functions such as torch.tensor, torch.empty, torch.zeros, torch.randn, torch.as_tensor, etc. If the parameter or buffer was instead assigned (e.g., from a torch.Tensor passed to __init__), there is no need to reset it.

Return type:

None
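
For illustration, a minimal sketch of a subclass (the model and its fields are hypothetical):

    import torch
    from cellarium.ml.models import CellariumModel

    class MyModel(CellariumModel):
        """Hypothetical subclass illustrating which state needs resetting."""

        def __init__(self, n_genes: int, mean_g: torch.Tensor) -> None:
            super().__init__()
            # Constructed with torch.empty inside __init__: must be reset.
            self.w_g = torch.nn.Parameter(torch.empty(n_genes))
            # Assigned from a tensor passed to __init__: no reset needed.
            self.register_buffer("mean_g", mean_g)
            self.reset_parameters()

        def reset_parameters(self) -> None:
            torch.nn.init.zeros_(self.w_g)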

class cellarium.ml.models.Geneformer(var_names_g: ndarray, hidden_size: int = 256, num_hidden_layers: int = 6, num_attention_heads: int = 4, intermediate_size: int = 512, hidden_act: str = 'relu', hidden_dropout_prob: float = 0.02, attention_probs_dropout_prob: float = 0.02, max_position_embeddings: int = 2048, type_vocab_size: int = 2, initializer_range: float = 0.02, position_embedding_type: str = 'absolute', layer_norm_eps: float = 1e-12, mlm_probability: float = 0.15)[source]

Bases: CellariumModel, PredictMixin

Geneformer model.

References:

  1. Transfer learning enables predictions in network biology (Theodoris et al.).

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • hidden_size (int) – Dimensionality of the encoder layers and the pooler layer.

  • num_hidden_layers (int) – Number of hidden layers in the Transformer encoder.

  • num_attention_heads (int) – Number of attention heads for each attention layer in the Transformer encoder.

  • intermediate_size (int) – Dimensionality of the “intermediate” (often called the feed-forward) layer in the Transformer encoder.

  • hidden_act (str) – The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "silu" and "gelu_new" are supported.

  • hidden_dropout_prob (float) – The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

  • attention_probs_dropout_prob (float) – The dropout ratio for the attention probabilities.

  • max_position_embeddings (int) – The maximum sequence length that this model might ever be used with. Typically set this to something large just in case (e.g., 512, 1024, or 2048).

  • type_vocab_size (int) – The vocabulary size of the token_type_ids passed when calling transformers.BertModel.

  • initializer_range (float) – The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

  • position_embedding_type (str) – Type of position embedding. Choose one of "absolute", "relative_key", "relative_key_query". For positional embeddings use "absolute". For more information on "relative_key", please refer to Self-Attention with Relative Position Representations (Shaw et al.). For more information on "relative_key_query", please refer to Method 4 in Improve Transformer Models with Better Relative Position Embeddings (Huang et al.).

  • layer_norm_eps (float) – The epsilon used by the layer normalization layers.

  • mlm_probability (float) – Ratio of tokens to mask for masked language modeling loss.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor][source]

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor]
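
A minimal training-style sketch (random counts and a tiny made-up gene schema; a real model uses the full gene vocabulary):

    import numpy as np
    import torch
    from cellarium.ml.models import Geneformer

    # Hypothetical 4-gene schema, purely for illustration.
    var_names_g = np.array(["ENSG0001", "ENSG0002", "ENSG0003", "ENSG0004"])
    model = Geneformer(var_names_g=var_names_g)

    x_ng = torch.randint(0, 10, (2, 4)).float()  # fake counts: 2 cells x 4 genes
    loss_dict = model(x_ng, var_names_g)         # dictionary with the MLM loss value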

predict(x_ng: Tensor, var_names_g: ndarray, output_hidden_states: bool = True, output_attentions: bool = True, output_input_ids: bool = True, output_attention_mask: bool = True, feature_activation: list[str] | None = None, feature_deletion: list[str] | None = None, feature_map: dict[str, int] | None = None) → dict[str, Tensor | ndarray][source]

Send (transformed) data through the model and return outputs. Optionally perform in silico perturbations and masking.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

  • output_hidden_states (bool) – Whether to return all hidden-states.

  • output_attentions (bool) – Whether to return all attentions.

  • output_input_ids (bool) – Whether to return input ids.

  • output_attention_mask (bool) – Whether to return attention mask.

  • feature_activation (list[str] | None) – Specify features whose expression should be set to a value greater than max(x_ng) before tokenization, moving them to the top rank.

  • feature_deletion (list[str] | None) – Specify features whose expression should be set to zero before tokenization (remove from inputs).

  • feature_map (dict[str, int] | None) – Specify a mapping applied to the input tokens before running the model.

Returns:

A dictionary with the inference results.

Return type:

dict[str, Tensor | ndarray]

Note

In silico perturbations can be achieved in one of three ways:

  1. Use feature_map to replace a feature token with MASK (1) or PAD (0) (e.g. feature_map={"ENSG0001": 1} will replace var_names_g feature ENSG0001 with a MASK token).

  2. Use feature_deletion to remove a feature from the cell’s inputs, which instead of adding a PAD or MASK token, will allow another feature to take its place (e.g. feature_deletion=["ENSG0001"] will remove var_names_g feature ENSG0001 from the input, and allow a new feature token to take its place).

  3. Use feature_activation to move a feature all the way to the top rank position in the input (e.g. feature_activation=["ENSG0001"] will make var_names_g feature ENSG0001 the first in rank order. Multiple input features will be ranked according to their order in the input list).

Options (2) and (3) are described in the Geneformer paper under “In silico perturbation” in the Methods section.
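
For example, a hedged sketch combining options (2) and (3) (fake counts and a made-up schema; the returned keys depend on the output_* flags):

    import numpy as np
    import torch
    from cellarium.ml.models import Geneformer

    var_names_g = np.array(["ENSG0001", "ENSG0002", "ENSG0003", "ENSG0004"])  # made-up schema
    model = Geneformer(var_names_g=var_names_g)
    x_ng = torch.randint(0, 10, (2, 4)).float()  # fake counts

    # Delete ENSG0001 from the inputs and move ENSG0002 to the top rank.
    out = model.predict(
        x_ng,
        var_names_g,
        feature_deletion=["ENSG0001"],
        feature_activation=["ENSG0002"],
    )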

class cellarium.ml.models.IncrementalPCA(var_names_g: ndarray, n_components: int, svd_lowrank_niter: int = 2, perform_mean_correction: bool = True)[source]

Bases: CellariumModel, PredictMixin

Distributed and Incremental PCA.

References:

  1. A Distributed and Incremental SVD Algorithm for Agglomerative Data Analysis on Large Networks (Iwen et al.).

  2. Incremental Learning for Robust Visual Tracking (Ross et al.).

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • n_components (int) – Number of principal components.

  • svd_lowrank_niter (int) – Number of iterations for the low-rank SVD algorithm.

  • perform_mean_correction (bool) – If True then the mean correction is applied to the update step. If False then the data is assumed to be centered and the mean correction is not applied to the update step.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | None][source]

Incrementally update partial SVD with new data.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

on_train_epoch_end(trainer: Trainer) → None[source]

Merge partial SVD results from parallel processes at the end of the epoch.

Merging SVDs is performed hierarchically. At each merging level, the leading process (even rank) merges its SVD with the trailing process (odd rank). The trailing process discards its SVD and is terminated. The leading process continues to the next level. This process continues until only one process remains. The final SVD is stored on the remaining process.

The number of levels (hierarchy depth) scales logarithmically with the number of processes.

Parameters:

trainer (Trainer)

Return type:

None

property explained_variance_k: Tensor

The amount of variance explained by each of the selected components. The variance estimation uses x_size degrees of freedom.

Equal to n_components largest eigenvalues of the covariance matrix of input data.

property components_kg: Tensor

Principal axes in feature space, representing the directions of maximum variance in the data. Equivalently, the right singular vectors of the centered input data, parallel to its eigenvectors. The components are sorted by decreasing explained_variance_k.

predict(x_ng: Tensor, var_names_g: ndarray) → dict[str, ndarray | Tensor][source]

Centering and embedding of the input data x_ng into the principal component space.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the following keys:

  • x_ng: Embedding of the input data into the principal component space.

  • var_names_g: The list of variable names for the output data.

Return type:

dict[str, ndarray | Tensor]
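
In normal use the forward pass is driven by a Trainer; purely for illustration, a single-process sketch with random data (shapes and gene names are made up):

    import numpy as np
    import torch
    from cellarium.ml.models import IncrementalPCA

    var_names_g = np.array([f"gene_{i}" for i in range(100)])  # hypothetical schema
    ipca = IncrementalPCA(var_names_g=var_names_g, n_components=10)

    # Incrementally update the partial SVD, one batch of cells at a time.
    for _ in range(5):
        ipca(torch.randn(64, 100), var_names_g)

    # Embed new data into the learned principal component space.
    out = ipca.predict(torch.randn(8, 100), var_names_g)
    z_nk = out["x_ng"]  # documented key: embedding in PC space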

class cellarium.ml.models.LogisticRegression(n_obs: int, var_names_g: ndarray, y_categories: ndarray, W_prior_scale: float = 1.0, W_init_scale: float = 1.0, seed: int = 0, log_metrics: bool = True)[source]

Bases: CellariumModel, PredictMixin

Logistic regression model.

Parameters:
  • n_obs (int) – Number of observations.

  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • y_categories (ndarray) – The categories for the target data.

  • W_prior_scale (float) – The scale of the Laplace prior for the weights.

  • W_init_scale (float) – Initialization scale for the W_gc parameter.

  • seed (int) – Random seed used to initialize parameters.

  • log_metrics (bool) – Whether to log the histogram of the W_gc parameter.

forward(x_ng: Tensor, var_names_g: ndarray, y_n: Tensor, y_categories: ndarray) → dict[str, Tensor | None][source]

Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

  • y_n (Tensor) – The target data.

  • y_categories (ndarray) – The categories for the input target data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor | None]

predict(x_ng: Tensor, var_names_g: ndarray) → dict[str, ndarray | Tensor][source]

Predict the target logits.

Parameters:
  • x_ng (Tensor) – The input data.

  • var_names_g (ndarray) – The variable names for the input data.

Returns:

A dictionary with the target logits.

Return type:

dict[str, ndarray | Tensor]
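
A hedged end-to-end sketch with fake data (in practice training is orchestrated by the Cellarium ML trainer; y_n is assumed here to hold integer codes into y_categories):

    import numpy as np
    import torch
    from cellarium.ml.models import LogisticRegression

    var_names_g = np.array([f"gene_{i}" for i in range(20)])  # hypothetical schema
    y_categories = np.array(["B cell", "T cell"])
    model = LogisticRegression(n_obs=1000, var_names_g=var_names_g, y_categories=y_categories)

    x_ng = torch.rand(16, 20)         # fake batch of 16 cells
    y_n = torch.randint(0, 2, (16,))  # fake labels, coded into y_categories
    loss_dict = model(x_ng, var_names_g, y_n, y_categories)

    out = model.predict(x_ng, var_names_g)  # dictionary with the target logits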

class cellarium.ml.models.abcdParameter(data: Tensor | None = None, width: int = 1, a: float = 0.0, b: float = 0.0, d: float = 0.0, init_std: float = 1.0, lr_scale: float = 1.0, base_width: int = 1)[source]

Bases: object

An abcd-parametrization describes the scaling of a parameter \(W\) with width \(n\). The parameter is initialized with a standard deviation \(\sigma\) and a parameter-specific learning rate scaling factor \(\alpha\) at base width \(n_0\). The scaling of the parameterization with width \(n\) is described by a set of numbers \(\{a, b, c, d\}\) such that:

  1. The parameter is given as \(W = \sqrt{\alpha} \cdot (n_0 / n)^a \cdot w\), where \(w\) is the learnable parameter.

  2. The learnable parameter is initialized as \(w \sim \mathcal{N}(0, \sigma \cdot (n_0 / n)^b / \sqrt{\alpha})\).

  3. The effective learning rate for \(W\) is \(\alpha \cdot (n_0 / n)^{2a} \cdot (n_0 / n)^c \cdot \eta\) for some global learning rate \(\eta\). In this implementation, \(c\) is equal to \(0\).

  4. The gradients of \(w\) are scaled by \((n / n_0)^d\).

Parameters:
  • data (Tensor | None) – The tensor data.

  • width (int) – The width \(n\) of the tensor.

  • a (float) – The \(a\) parameter.

  • b (float) – The \(b\) parameter.

  • d (float) – The \(d\) parameter.

  • init_std (float) – The initialization standard deviation \(\sigma\) at base width \(n_0\).

  • lr_scale (float) – The learning rate scale factor \(\alpha\) at base width \(n_0\).

  • base_width (int) – The base width \(n_0\).
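
As a construction sketch, the hidden-layer column of the Adam table in MuLinear below corresponds to \(a = 1\), \(b = -0.5\), \(d = 2\) (widths here are arbitrary):

    import torch
    from cellarium.ml.models import abcdParameter

    # Hidden-layer weight under the Adam maximal update parametrization:
    # a = 1, b = -0.5, d = 2 (see the MuLinear tables below).
    w = abcdParameter(
        data=torch.empty(512, 512),
        width=512,       # current width n
        a=1.0,
        b=-0.5,
        d=2.0,
        init_std=1.0,    # sigma at base width
        lr_scale=1.0,    # alpha at base width
        base_width=256,  # n_0
    )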

class cellarium.ml.models.MuLinear(in_features: int, out_features: int, bias: bool, layer: Literal['input', 'hidden', 'output'], optimizer: Literal['sgd', 'adam', 'adamw'], weight_init_std: float = 1.0, bias_init_std: float = 0.0, lr_scale: float = 1.0, base_width: int = 1)[source]

Bases: Module

Linear layer with a maximal update parametrization.

The maximal update parametrization for SGD is defined by:

         Input & Biases   Hidden        Output
\(a\)    \(-0.5\)         \(0\)         \(0.5\)
\(b\)    \(0.5\)          \(0.5\)       \(0.5\)
\(c\)    \(0\)            \(0\)         \(0\)
\(d\)    \(0\)            \(0\)         \(0\)
\(n\)    out_features     in_features   in_features

The maximal update parametrization for Adam and AdamW is defined by:

         Input & Biases   Hidden        Output
\(a\)    \(0\)            \(1\)         \(1\)
\(b\)    \(0\)            \(-0.5\)      \(0\)
\(c\)    \(0\)            \(0\)         \(0\)
\(d\)    \(1\)            \(2\)         \(1\)
\(n\)    out_features     in_features   in_features

Since in this implementation \(c\) always equals 0, regular PyTorch optimizers can be used.

References:

  1. Feature Learning in Infinite-Width Neural Networks (Yang et al.).

  2. Tensor Programs IVb: Adaptive Optimization in the ∞-Width Limit (Yang et al.).

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Size of each output sample.

  • bias (bool) – If set to False, the layer will not learn an additive bias.

  • layer (Literal['input', 'hidden', 'output']) – Layer type.

  • optimizer (Literal['sgd', 'adam', 'adamw']) – Optimizer type.

  • weight_init_std (float) – The standard deviation of the weight initialization at base width.

  • bias_init_std (float) – The standard deviation of the bias initialization at base width.

  • lr_scale (float) – The learning rate scaling factor for the weight and the bias.

  • base_width (int) – The base width of the layer.

property weight: Tensor

The weights of the module of shape (out_features, in_features). The weight-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above.

property bias: Tensor | None

The bias of the module of shape (out_features). If bias is True, the bias-specific learning rate and the initialization standard deviation are scaled with the width of the layer according to the table above.
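
To illustrate, a sketch of a small muP MLP (layer sizes and base_width are arbitrary); because \(c = 0\) in this implementation, a stock optimizer with one global learning rate suffices:

    import torch
    from cellarium.ml.models import MuLinear

    width = 512  # hypothetical hidden width

    mlp = torch.nn.Sequential(
        MuLinear(100, width, bias=True, layer="input", optimizer="adamw", base_width=256),
        torch.nn.ReLU(),
        MuLinear(width, width, bias=True, layer="hidden", optimizer="adamw", base_width=256),
        torch.nn.ReLU(),
        MuLinear(width, 10, bias=True, layer="output", optimizer="adamw", base_width=256),
    )

    # Since c = 0, a regular PyTorch optimizer can be used directly.
    optimizer = torch.optim.AdamW(mlp.parameters(), lr=1e-3)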

class cellarium.ml.models.OnePassMeanVarStd(var_names_g: ndarray, algorithm: Literal['naive', 'shifted_data'] = 'naive')[source]

Bases: CellariumModel

Calculate the mean, variance, and standard deviation of the data in one pass (epoch) using running sums and running squared sums.

References:

  1. Algorithms for calculating variance.

Parameters:
  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • algorithm (Literal['naive', 'shifted_data']) – The algorithm used to compute the variance: naive accumulates running sums of the values and of their squares, while shifted_data first subtracts a constant shift from the data for improved numerical stability (see reference 1).

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | None][source]

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

property mean_g: Tensor

Mean of the data.

property var_g: Tensor

Variance of the data.

property std_g: Tensor

Standard deviation of the data.
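
A single-process usage sketch with fake data (normally the pass over the data is driven by a Trainer):

    import numpy as np
    import torch
    from cellarium.ml.models import OnePassMeanVarStd

    var_names_g = np.array([f"gene_{i}" for i in range(50)])  # hypothetical schema
    stats = OnePassMeanVarStd(var_names_g=var_names_g, algorithm="shifted_data")

    # One pass over the data: accumulate running sums batch by batch.
    for _ in range(10):
        stats(torch.rand(32, 50), var_names_g)

    mean_g, var_g, std_g = stats.mean_g, stats.var_g, stats.std_g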

class cellarium.ml.models.PredictMixin[source]

Bases: object

Abstract mixin class for models that can perform prediction.

abstract predict(*args: Any, **kwargs: Any) → dict[str, ndarray | Tensor][source]

Perform prediction.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

dict[str, ndarray | Tensor]
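
A minimal sketch of a model opting into prediction (the class and its behavior are hypothetical):

    import numpy as np
    import torch
    from cellarium.ml.models import CellariumModel, PredictMixin

    class IdentityEmbedder(CellariumModel, PredictMixin):
        """Hypothetical model whose prediction is the unchanged input."""

        def reset_parameters(self) -> None:
            pass  # nothing constructed in __init__, so nothing to reset

        def predict(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, np.ndarray | torch.Tensor]:
            return {"x_ng": x_ng, "var_names_g": var_names_g}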

class cellarium.ml.models.ProbabilisticPCA(n_obs: int, var_names_g: ndarray, n_components: int, ppca_flavor: Literal['marginalized', 'linear_vae'], mean_g: Tensor | None = None, W_init_scale: float = 1.0, sigma_init_scale: float = 1.0, seed: int = 0)[source]

Bases: CellariumModel, PredictMixin

Probabilistic PCA implemented in Pyro.

Two flavors of probabilistic PCA are available: marginalized pPCA [1] and linear VAE [2].

References:

  1. Probabilistic Principal Component Analysis (Tipping et al.).

  2. Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).

Parameters:
  • n_obs (int) – Number of cells.

  • var_names_g (ndarray) – The variable names schema for the input data validation.

  • n_components (int) – Number of principal components.

  • ppca_flavor (Literal['marginalized', 'linear_vae']) – Type of the PPCA model. Has to be one of marginalized or linear_vae.

  • mean_g (Tensor | None) – Mean gene expression of the input data. If None then the mean is set to a learnable parameter.

  • W_init_scale (float) – Scale of the random initialization of the W_kg parameter.

  • sigma_init_scale (float) – Initialization value of the sigma parameter.

  • seed (int) – Random seed used to initialize parameters.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | None][source]

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the loss value.

Return type:

dict[str, Tensor | None]

predict(x_ng: Tensor, var_names_g: ndarray) → dict[str, ndarray | Tensor][source]

Centering and embedding of the input data x_ng into the principal component space.

Note

Gradients are disabled, used for inference only.

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the following keys:

  • z_nk: Embedding of the input data into the principal component space.

Return type:

dict[str, ndarray | Tensor]

property L_k: Tensor

Vector with elements given by the PC eigenvalues.

Note

Gradients are disabled, used for inference only.

property U_gk: Tensor

Principal components corresponding to eigenvalues L_k.

Note

Gradients are disabled, used for inference only.

property W_variance: float

Note

Gradients are disabled, used for inference only.

property sigma_variance: float

Note

Gradients are disabled, used for inference only.
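
For illustration, a sketch of embedding data and inspecting the learned quantities (shapes are made up, and training itself would be run separately):

    import numpy as np
    import torch
    from cellarium.ml.models import ProbabilisticPCA

    var_names_g = np.array([f"gene_{i}" for i in range(100)])  # hypothetical schema
    ppca = ProbabilisticPCA(
        n_obs=10_000,
        var_names_g=var_names_g,
        n_components=10,
        ppca_flavor="marginalized",
    )

    # After training, embed new data and inspect the learned components.
    out = ppca.predict(torch.randn(8, 100), var_names_g)
    z_nk = out["z_nk"]      # documented key: embedding in PC space
    eigenvalues = ppca.L_k  # PC eigenvalues
    components = ppca.U_gk  # principal components for eigenvalues L_k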

class cellarium.ml.models.TDigest(var_names_g: ndarray)[source]

Bases: CellariumModel

Compute an approximate non-zero histogram of the distribution of each gene in a batch of cells using t-digests.

References:

  1. Computing Extremely Accurate Quantiles Using T-Digests (Dunning et al.).

Parameters:

var_names_g (ndarray) – The variable names schema for the input data validation.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | None][source]

Parameters:
  • x_ng (Tensor) – Gene counts matrix.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

An empty dictionary.

Return type:

dict[str, Tensor | None]

property median_g: Tensor

Median of the data.
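
A usage sketch with fake batches (gene names are made up):

    import numpy as np
    import torch
    from cellarium.ml.models import TDigest

    var_names_g = np.array([f"gene_{i}" for i in range(50)])  # hypothetical schema
    tdigest = TDigest(var_names_g=var_names_g)

    # Accumulate a per-gene t-digest over batches of cells.
    for _ in range(10):
        tdigest(torch.rand(32, 50), var_names_g)

    median_g = tdigest.median_g  # approximate per-gene median of the non-zero values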

class cellarium.ml.models.ValidateMixin[source]

Bases: object

Abstract mixin class for models that can perform validation.

abstract validate(*args: Any, **kwargs: Any) → None[source]

Perform validation.

Parameters:
  • args (Any)

  • kwargs (Any)

Return type:

None