Transforms

class cellarium.ml.transforms.CellariumGPTTrainTokenizer(context_len: int, gene_downsample_fraction: float, min_total_mrna_umis: int, max_total_mrna_umis: int, gene_vocab_sizes: dict[str, int], metadata_vocab_sizes: dict[str, int], ontology_downsample_p: float, ontology_infos_path: str, prefix_len: int | None = None, metadata_prompt_token_list: list[str] | None = None, obs_names_rng: bool = False)

Bases: Module

Tokenizer for the Cellarium GPT model.

Parameters:
  • context_len (int) – Context length.

  • gene_downsample_fraction (float) – Fraction of genes to downsample.

  • min_total_mrna_umis (int) – Minimum total mRNA UMIs.

  • max_total_mrna_umis (int) – Maximum total mRNA UMIs.

  • gene_vocab_sizes (dict[str, int]) – Gene token vocabulary sizes.

  • metadata_vocab_sizes (dict[str, int]) – Metadata token vocabulary sizes.

  • ontology_infos_path (str) – Path to ontology information.

  • prefix_len (int | None) – Prefix length. If None, the prefix length is sampled.

  • metadata_prompt_token_list (list[str] | None) – List of metadata tokens to prompt. If None, the metadata prompt tokens are sampled.

  • obs_names_rng (bool) – If True, cell IDs are used as random seeds for shuffling gene tokens. If False, gene tokens are shuffled without a cell-specific random seed.

  • ontology_downsample_p (float)
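The obs_names_rng option ties the gene-token shuffle to the cell it comes from. The sketch below shows one way such per-cell seeding could work; it is not the Cellarium implementation. The helper names seed_from_cell_id and shuffle_gene_tokens and the SHA-256 seeding scheme are hypothetical, and NumPy stands in for the module's own tensor code.

```python
from __future__ import annotations

import hashlib

import numpy as np


def seed_from_cell_id(cell_id: str) -> int:
    """Derive a deterministic 64-bit seed from a cell ID (hypothetical scheme)."""
    # A stable digest gives the same seed for the same cell ID in every process.
    digest = hashlib.sha256(cell_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "little")


def shuffle_gene_tokens(gene_ids: np.ndarray, cell_id: str | None = None) -> np.ndarray:
    """Return gene tokens in random order, reproducible per cell when cell_id is given."""
    # cell_id=None mimics obs_names_rng=False: a fresh, unseeded generator.
    seed = None if cell_id is None else seed_from_cell_id(cell_id)
    rng = np.random.default_rng(seed)
    return gene_ids[rng.permutation(len(gene_ids))]
```

A digest is used instead of Python's built-in hash() because string hashing is randomized per process, so the same cell would otherwise be shuffled differently across runs.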

class cellarium.ml.transforms.DivideByScale(scale_g: Tensor, var_names_g: ndarray, eps: float = 1e-06)

Bases: Module

Divide gene counts by a per-gene scale.

\[y_{ng} = \frac{x_{ng}}{\mathrm{scale}_g + \mathrm{eps}}\]
Parameters:
  • scale_g (Tensor) – A scale for each gene.

  • var_names_g (ndarray) – The expected variable names, used to validate the input data.

  • eps (float) – A value added to the denominator for numerical stability.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor]
Parameters:
  • x_ng (Tensor) – Gene counts.

  • var_names_g (ndarray) – The variable names of the input data, checked against the var_names_g schema given at construction. If None, no validation is performed.

Returns:

A dictionary with the following keys:

  • x_ng: The gene counts divided by the scale.

Return type:

dict[str, Tensor]
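A minimal NumPy sketch of the DivideByScale formula and its schema check, assuming validation means an exact, order-sensitive match of gene names; the real module works on torch Tensors and may validate differently. The function name divide_by_scale and the argument schema_var_names_g are hypothetical.

```python
from __future__ import annotations

import numpy as np


def divide_by_scale(
    x_ng: np.ndarray,
    var_names_g: np.ndarray | None,
    scale_g: np.ndarray,
    schema_var_names_g: np.ndarray,
    eps: float = 1e-6,
) -> dict[str, np.ndarray]:
    """NumPy sketch of y_ng = x_ng / (scale_g + eps)."""
    # Schema check: the input genes must match the genes scale_g was built for.
    if var_names_g is not None and not np.array_equal(var_names_g, schema_var_names_g):
        raise ValueError("var_names_g does not match the expected gene schema")
    # scale_g has shape (g,), so it broadcasts across every cell n.
    return {"x_ng": x_ng / (scale_g + eps)}
```

Checking the names before dividing matters because a scale vector applied to genes in a different order would still broadcast without error and silently produce wrong values.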

class cellarium.ml.transforms.Filter(filter_list: Sequence[str])

Bases: Module

Filter gene counts by a list of features.

\[ \begin{align}\begin{aligned}\mathrm{mask}_g = \mathrm{feature}_g \in \mathrm{filter\_list}\\y_{ng} = x_{ng}[:, \mathrm{mask}_g]\end{aligned}\end{align} \]
Parameters:

filter_list (Sequence[str]) – A list of features to filter by.

filter(var_names_g: tuple) → ndarray[Any, dtype[int64]]
Parameters:

var_names_g (tuple) – The list of the variable names in the input data.

Returns:

An array of indices of the features in var_names_g that are in filter_list.

Return type:

ndarray[Any, dtype[int64]]

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | ndarray]

Note

When used with CellariumModule or CellariumPipeline, the x_ng and var_names_g keys in the input dictionary are overwritten with the filtered values.

Parameters:
  • x_ng (Tensor) – Gene counts.

  • var_names_g (ndarray) – The list of the variable names in the input data.

Returns:

A dictionary with the following keys:

  • x_ng: Gene counts filtered by filter_list.

  • var_names_g: The variable names of the input data filtered by filter_list.

Return type:

dict[str, Tensor | ndarray]
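The mask in the Filter formula can be sketched with NumPy as below. This assumes kept genes stay in their var_names_g order, which matches the mask definition but may differ from the library's actual ordering; filter_indices and filter_genes are hypothetical names, not the Filter API.

```python
from collections.abc import Sequence

import numpy as np


def filter_indices(var_names_g: Sequence[str], filter_list: Sequence[str]) -> np.ndarray:
    """Indices of the entries of var_names_g that appear in filter_list."""
    keep = set(filter_list)  # set membership is O(1) per gene
    return np.array([i for i, name in enumerate(var_names_g) if name in keep], dtype=np.int64)


def filter_genes(x_ng: np.ndarray, var_names_g: Sequence[str], filter_list: Sequence[str]) -> dict:
    """Keep only the columns of x_ng whose gene names are in filter_list."""
    idx = filter_indices(var_names_g, filter_list)
    # Select the kept columns of x_ng together with the matching gene names.
    return {"x_ng": x_ng[:, idx], "var_names_g": np.asarray(var_names_g)[idx]}
```

Returning the filtered var_names_g alongside x_ng keeps the two aligned, so a later transform that validates gene names sees the reduced schema.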

class cellarium.ml.transforms.Log1p(*args, **kwargs)

Bases: Module

Apply a log1p transform to gene counts.

\[y_{ng} = \log(1 + x_{ng})\]
forward(x_ng: Tensor) → dict[str, Tensor]

Note

When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary is overwritten with the log1p transformed values.

Parameters:

x_ng (Tensor) – Gene counts.

Returns:

A dictionary with the following keys:

  • x_ng: The log1p transformed gene counts.

Return type:

dict[str, Tensor]
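The sketch below pairs the Log1p formula with the key-overwrite behaviour described in the Note, modelled as a plain dict.update. It is a NumPy illustration under that assumption, not the CellariumPipeline code; log1p_transform, batch, and the obs_names_n key are hypothetical.

```python
import numpy as np


def log1p_transform(x_ng: np.ndarray) -> dict[str, np.ndarray]:
    """NumPy sketch of y_ng = log(1 + x_ng)."""
    # np.log1p is more accurate than np.log(1 + x) for small x, and maps 0 to 0.
    return {"x_ng": np.log1p(x_ng)}


# Hypothetical pipeline step: returned keys overwrite the same keys in the batch,
# while keys the transform does not return are left untouched.
batch = {"x_ng": np.array([[0.0, 1.0, 9.0]]), "obs_names_n": np.array(["cell_0"])}
batch.update(log1p_transform(batch["x_ng"]))
```

Because zero counts map to zero, the transform preserves sparsity in count matrices, which is one common reason for preferring log1p over a plain log.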

class cellarium.ml.transforms.NormalizeTotal(target_count: int = 10000, eps: float = 1e-06)

Bases: Module

Normalize the total gene counts of each cell to a target count.

\[ \begin{align}\begin{aligned}\mathrm{total\_mrna\_umis}_n = \sum_{g=1}^G x_{ng}\\y_{ng} = \frac{\mathrm{target\_count} \times x_{ng}}{\mathrm{total\_mrna\_umis}_n + \mathrm{eps}}\end{aligned}\end{align} \]
Parameters:
  • target_count (int) – Target total gene expression count per cell.

  • eps (float) – A value added to the denominator for numerical stability.

forward(x_ng: Tensor, total_mrna_umis_n: Tensor | None = None) → dict[str, Tensor]

Note

When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary is overwritten with the normalized values.

Parameters:
  • x_ng (Tensor) – Gene counts.

  • total_mrna_umis_n (Tensor | None) – Total mRNA UMI counts per cell. If None, it is computed from x_ng.

Returns:

A dictionary with the following keys:

  • x_ng: The gene counts normalized to the target count.

Return type:

dict[str, Tensor]
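A minimal NumPy sketch of the NormalizeTotal formula, including the optional precomputed totals; normalize_total is a hypothetical name and the real module operates on torch Tensors. One plausible reason to pass total_mrna_umis_n explicitly is that the totals were computed over genes not present in x_ng.

```python
from __future__ import annotations

import numpy as np


def normalize_total(
    x_ng: np.ndarray,
    total_mrna_umis_n: np.ndarray | None = None,
    target_count: int = 10_000,
    eps: float = 1e-6,
) -> dict[str, np.ndarray]:
    """NumPy sketch of y_ng = target_count * x_ng / (total_mrna_umis_n + eps)."""
    if total_mrna_umis_n is None:
        # Per-cell totals computed from the genes present in x_ng.
        total_mrna_umis_n = x_ng.sum(axis=1)
    # [:, None] reshapes totals to (n, 1) so each row is divided by its own total.
    return {"x_ng": target_count * x_ng / (total_mrna_umis_n[:, None] + eps)}
```

For example, a cell with counts [1, 3] has a total of 4, so with the default target of 10000 it becomes approximately [2500, 7500].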

class cellarium.ml.transforms.ZScore(mean_g: Tensor, std_g: Tensor, var_names_g: ndarray, eps: float = 1e-06)

Bases: Module

Z-score gene counts using per-gene means and standard deviations.

\[y_{ng} = \frac{x_{ng} - \mathrm{mean}_g}{\mathrm{std}_g + \mathrm{eps}}\]
Parameters:
  • mean_g (Tensor) – Means for each gene.

  • std_g (Tensor) – Standard deviations for each gene.

  • var_names_g (ndarray) – The expected variable names, used to validate the input data.

  • eps (float) – A value added to the denominator for numerical stability.

forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor]

Note

When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary is overwritten with the z-scored values.

Parameters:
  • x_ng (Tensor) – Gene counts.

  • var_names_g (ndarray) – The variable names of the input data, checked against the var_names_g schema given at construction. If None, no validation is performed.

Returns:

A dictionary with the following keys:

  • x_ng: The z-scored gene counts.

Return type:

dict[str, Tensor]
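A minimal NumPy sketch of the ZScore formula, with mean_g and std_g estimated from a small hypothetical reference matrix. How Cellarium computes these statistics is not stated here; the population standard deviation (NumPy's default ddof=0) is an assumption, and zscore is a hypothetical name.

```python
import numpy as np


def zscore(x_ng: np.ndarray, mean_g: np.ndarray, std_g: np.ndarray, eps: float = 1e-6) -> dict[str, np.ndarray]:
    """NumPy sketch of y_ng = (x_ng - mean_g) / (std_g + eps)."""
    # mean_g and std_g have shape (g,) and broadcast across every cell n.
    return {"x_ng": (x_ng - mean_g) / (std_g + eps)}


# Per-gene statistics estimated from hypothetical reference data (3 cells, 2 genes).
reference_ng = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
mean_g = reference_ng.mean(axis=0)
std_g = reference_ng.std(axis=0)  # ddof=0: population standard deviation
z_ng = zscore(reference_ng, mean_g, std_g)["x_ng"]
```

Applied to the data the statistics came from, each gene column ends up with mean 0 and a standard deviation just under 1, the small shortfall coming from eps in the denominator.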