Transforms
- class cellarium.ml.transforms.BinomialResample(p_binom_min: float, p_binom_max: float, p_apply: float)
Bases: Module
Binomial resampling of gene counts.
For each count, the success probability of the binomial distribution is independently and uniformly sampled according to the bounding parameters, yielding the parameter matrix p_ng.
\[y_{ng} = \mathrm{Binomial}(n=x_{ng}, p=p_{ng})\]
- Parameters:
p_binom_min (float) – Lower bound on binomial distribution parameter.
p_binom_max (float) – Upper bound on binomial distribution parameter.
p_apply (float) – Probability of applying transform to each sample.
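The resampling can be sketched with NumPy as follows. This is an illustrative reimplementation of the formula above, not the library's code. The helper name binomial_resample is hypothetical. The sketch assumes the input holds integer-valued counts and that p_apply is evaluated once per row (cell).

```python
import numpy as np


def binomial_resample(x_ng, p_binom_min, p_binom_max, p_apply, rng):
    """Hypothetical NumPy sketch of per-entry binomial resampling."""
    n, g = x_ng.shape
    # One success probability per count, drawn uniformly between the bounds
    p_ng = rng.uniform(p_binom_min, p_binom_max, size=(n, g))
    # y_ng ~ Binomial(n=x_ng, p=p_ng); assumes x_ng holds integer-valued counts
    y_ng = rng.binomial(x_ng.astype(np.int64), p_ng).astype(x_ng.dtype)
    # Assumption: each row (cell) is transformed with probability p_apply
    apply_n = rng.random(n) < p_apply
    return np.where(apply_n[:, None], y_ng, x_ng)


rng = np.random.default_rng(0)
x_ng = np.array([[10.0, 0.0, 5.0], [3.0, 7.0, 1.0]])
y_ng = binomial_resample(x_ng, 0.2, 0.8, p_apply=1.0, rng=rng)
```

A binomial draw never exceeds its trial count. Every resampled count therefore stays between 0 and the original count, so this transform can only thin a cell's counts.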
- class cellarium.ml.transforms.CellariumGPTTrainTokenizer(context_len: int, gene_downsample_fraction: float, min_total_mrna_umis: int, max_total_mrna_umis: int, gene_vocab_sizes: dict[str, int], metadata_vocab_sizes: dict[str, int], ontology_downsample_p: float, ontology_infos_path: str, prefix_len: int | None = None, metadata_prompt_token_list: list[str] | None = None, obs_names_rng: bool = False)
Bases: Module
Tokenizer for the Cellarium GPT model.
- Parameters:
context_len (int) – Context length.
gene_downsample_fraction (float) – Fraction of genes to downsample.
min_total_mrna_umis (int) – Minimum total mRNA UMIs.
max_total_mrna_umis (int) – Maximum total mRNA UMIs.
gene_vocab_sizes (dict[str, int]) – Gene token vocabulary sizes.
metadata_vocab_sizes (dict[str, int]) – Metadata token vocabulary sizes.
ontology_infos_path (str) – Path to ontology information.
prefix_len (int | None) – Prefix length. If None, the prefix length is sampled.
metadata_prompt_token_list (list[str] | None) – List of metadata tokens to prompt. If None, the metadata prompt tokens are sampled.
obs_names_rng (bool) – If True, cell IDs are used as random seeds for shuffling gene tokens. If False (default), gene tokens are shuffled without a random seed.
ontology_downsample_p (float)
- class cellarium.ml.transforms.DivideByScale(scale_g: Tensor, var_names_g: ndarray, eps: float = 1e-06)
Bases: FilterCompatibilityMixin, Module
Divide gene counts by a scale.
\[y_{ng} = \frac{x_{ng}}{\mathrm{scale}_g + \mathrm{eps}}\]
- Parameters:
scale_g (Tensor) – A scale for each gene.
var_names_g (ndarray) – The variable names schema for the input data validation.
eps (float) – A value added to the denominator for numerical stability.
- forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor]
- Parameters:
x_ng (Tensor) – Gene counts.
var_names_g (ndarray) – The list of the variable names in the input data. Must be a subset of (or equal to) the var_names_g schema the transform was initialized with, in any order.
- Returns:
A dictionary with the following keys:
x_ng: The gene counts divided by the scale.
- Return type:
dict[str, Tensor]
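forward accepts any subset of the schema genes in any order. The per-gene scale therefore presumably has to be aligned to the input columns by name before dividing. The NumPy sketch below illustrates that reading of the docs. divide_by_scale and its name-based lookup are assumptions, not the library's implementation.

```python
import numpy as np


def divide_by_scale(x_ng, var_names_g, scale_g, schema_var_names_g, eps=1e-6):
    """Hypothetical sketch: y_ng = x_ng / (scale_g + eps), aligned by gene name."""
    # Position of every schema gene in scale_g
    position = {name: i for i, name in enumerate(schema_var_names_g)}
    missing = [name for name in var_names_g if name not in position]
    if missing:
        raise ValueError(f"genes not in the schema: {missing}")
    # Pick the scale of each input column, following the input's own order
    idx = np.array([position[name] for name in var_names_g], dtype=np.int64)
    return {"x_ng": x_ng / (scale_g[idx] + eps)}


schema = np.array(["A", "B", "C"])
scale_g = np.array([1.0, 2.0, 4.0])
x_ng = np.array([[8.0, 4.0]])  # columns are genes "C" and "A"
out = divide_by_scale(x_ng, np.array(["C", "A"]), scale_g, schema, eps=0.0)
```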
- class cellarium.ml.transforms.Dropout(p_dropout_min, p_dropout_max, p_apply)
Bases: Module
Applies random dropout to gene counts.
For each count, the dropout parameter is independently and uniformly sampled according to the bounding parameters, yielding the parameter matrix p_ng.
\[y_{ng} = x_{ng} \cdot (1 - \mathrm{Bernoulli}(p_{ng}))\]
- Parameters:
p_dropout_min – Lower bound on dropout parameter.
p_dropout_max – Upper bound on dropout parameter.
p_apply – Probability of applying transform to each sample.
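A NumPy sketch of the dropout formula follows. It is an illustration, not the library's code. The helper dropout is hypothetical, and applying p_apply once per row (cell) is an assumption about what "each sample" means.

```python
import numpy as np


def dropout(x_ng, p_dropout_min, p_dropout_max, p_apply, rng):
    """Hypothetical NumPy sketch of y_ng = x_ng * (1 - Bernoulli(p_ng))."""
    n, g = x_ng.shape
    # One dropout probability per count, drawn uniformly between the bounds
    p_ng = rng.uniform(p_dropout_min, p_dropout_max, size=(n, g))
    # Bernoulli(p_ng) draws as 0/1 values; 1 means the count is zeroed
    bernoulli_ng = (rng.random((n, g)) < p_ng).astype(x_ng.dtype)
    y_ng = x_ng * (1 - bernoulli_ng)
    # Assumption: each row (cell) is transformed with probability p_apply
    apply_n = rng.random(n) < p_apply
    return np.where(apply_n[:, None], y_ng, x_ng)


rng = np.random.default_rng(0)
x_ng = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y_ng = dropout(x_ng, 0.3, 0.7, p_apply=1.0, rng=rng)
```

Each output entry is either the original count or zero. With both bounds set to 1.0, every count in a transformed row is dropped.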
- class cellarium.ml.transforms.Duplicate(enabled=True)
Bases: Module
Duplicates every row of the input tensor, used for contrastive augmentations.
- class cellarium.ml.transforms.Filter(filter_list: Sequence[str], ordering: bool = True, allow_missing: bool = False)
Bases: Module
Filter gene counts by a list of features.
When ordering=False, the output columns follow the order in which genes appear in the input var_names_g:
\[ \begin{aligned}\mathrm{mask}_g &= \mathrm{feature}_g \in \mathrm{filter\_list}\\ y_{ng} &= x_{ng}[:, \mathrm{mask}_g]\end{aligned} \]
When ordering=True (default), the output columns follow the order of filter_list:
\[y_{ng} = x_{ng}[:, \sigma(\mathrm{filter\_list})]\]
where \(\sigma\) maps each entry in filter_list to its column index in the input.
- Parameters:
filter_list (Sequence[str]) – A list of features to filter by.
ordering (bool) – If True (default), output columns are ordered to match filter_list. If False, output columns follow the order in which genes appear in the input var_names_g. Use ordering=True when running inference on data whose gene ordering differs from the one seen during training.
allow_missing (bool) – If True, genes in filter_list that are absent from the input are zero-filled in the output. Requires ordering=True. If False (default), all genes in filter_list must be present in the input.
- filter(var_names_g: tuple) → ndarray[Any, dtype[int64]] | tuple[ndarray[Any, dtype[int64]], ndarray[Any, dtype[int64]]]
- Parameters:
var_names_g (tuple) – The list of the variable names in the input data.
- Returns:
When ordering=False: a 1-D array of source indices in input order.
When ordering=True and allow_missing=False: a 1-D array of source indices ordered to match filter_list.
When ordering=True and allow_missing=True: a tuple (src_indices, out_indices) where src_indices indexes columns in var_names_g and out_indices gives the corresponding destination column in the output (which always has len(filter_list) columns).
- Return type:
ndarray[Any, dtype[int64]] | tuple[ndarray[Any, dtype[int64]], ndarray[Any, dtype[int64]]]
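The three return forms can be sketched in plain Python with NumPy arrays. filter_indices is a hypothetical helper written from the description above, not the library's method. Raising ValueError for missing genes, or for allow_missing=True combined with ordering=False, is an assumption about how the documented requirements are enforced.

```python
import numpy as np


def filter_indices(var_names_g, filter_list, ordering=True, allow_missing=False):
    """Hypothetical sketch of the index arrays described for Filter.filter."""
    if allow_missing and not ordering:
        raise ValueError("allow_missing=True requires ordering=True")
    if not ordering:
        # Keep matching columns in the order they appear in the input
        keep = set(filter_list)
        return np.array(
            [i for i, name in enumerate(var_names_g) if name in keep], dtype=np.int64
        )
    position = {name: i for i, name in enumerate(var_names_g)}
    if not allow_missing:
        missing = [name for name in filter_list if name not in position]
        if missing:
            raise ValueError(f"genes missing from the input: {missing}")
        # Source column for each filter_list entry, in filter_list order
        return np.array([position[name] for name in filter_list], dtype=np.int64)
    # Pair each present gene's source column with its output column
    pairs = [(position[name], j) for j, name in enumerate(filter_list) if name in position]
    src_indices = np.array([s for s, _ in pairs], dtype=np.int64)
    out_indices = np.array([o for _, o in pairs], dtype=np.int64)
    return src_indices, out_indices


names = ("g1", "g2", "g3", "g4")
input_order = filter_indices(names, ["g3", "g1"], ordering=False)
list_order = filter_indices(names, ["g3", "g1"])
src, out = filter_indices(names, ["g3", "gX", "g1"], allow_missing=True)
```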
- forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor | ndarray]
Note
When used with CellariumModule or CellariumPipeline, the x_ng and var_names_g keys in the input dictionary will be overwritten with the filtered values.
- Parameters:
x_ng (Tensor) – Gene counts.
var_names_g (ndarray) – The list of the variable names in the input data.
- Returns:
A dictionary with the following keys:
x_ng: Gene counts filtered (and reordered if ordering=True) to match filter_list. Shape (n, len(filter_list)) when ordering=True, otherwise (n, num_matched).
var_names_g: Gene names corresponding to the output columns.
- Return type:
dict[str, Tensor | ndarray]
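For the allow_missing=True case, the (src_indices, out_indices) pair returned by filter() is enough to assemble the output. The hypothetical NumPy helper gather_with_zero_fill sketches one way to do this. It allocates a zero matrix with len(filter_list) columns and scatters the matched input columns into it. This illustrates the documented zero-fill behavior and is not the library's implementation.

```python
import numpy as np


def gather_with_zero_fill(x_ng, src_indices, out_indices, n_out):
    """Hypothetical sketch: scatter matched input columns into a zero-filled output."""
    # The output always has n_out (= len(filter_list)) columns; absent genes stay 0
    y_ng = np.zeros((x_ng.shape[0], n_out), dtype=x_ng.dtype)
    y_ng[:, out_indices] = x_ng[:, src_indices]
    return y_ng


x_ng = np.array([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
# Input genes: g1, g2, g3, g4; filter_list: ["g3", "gX", "g1"]
y_ng = gather_with_zero_fill(x_ng, np.array([2, 0]), np.array([0, 2]), n_out=3)
```

The middle output column corresponds to the absent gene "gX" and is therefore all zeros.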
- class cellarium.ml.transforms.GaussianNoise(sigma_min, sigma_max, p_apply)
Bases: Module
Adds Gaussian noise to gene counts.
For each count, the Gaussian sigma is independently and uniformly sampled according to the bounding parameters, yielding the sigma matrix sigma_ng.
\[y_{ng} = x_{ng} + \mathcal{N}(0, \sigma_{ng})\]
- Parameters:
sigma_min – Lower bound on Gaussian sigma parameter.
sigma_max – Upper bound on Gaussian sigma parameter.
p_apply – Probability of applying transform to each sample.
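A NumPy sketch of the noise step follows. The helper gaussian_noise is hypothetical. Treating sigma_ng as the standard deviation of zero-mean noise, and applying p_apply once per row (cell), are assumptions rather than documented behavior.

```python
import numpy as np


def gaussian_noise(x_ng, sigma_min, sigma_max, p_apply, rng):
    """Hypothetical NumPy sketch of y_ng = x_ng + N(0, sigma_ng)."""
    n, g = x_ng.shape
    # One sigma per count, drawn uniformly between the bounds
    sigma_ng = rng.uniform(sigma_min, sigma_max, size=(n, g))
    # Assumption: sigma_ng is the standard deviation of zero-mean noise
    y_ng = x_ng + rng.normal(0.0, sigma_ng)
    # Assumption: each row (cell) is transformed with probability p_apply
    apply_n = rng.random(n) < p_apply
    return np.where(apply_n[:, None], y_ng, x_ng)


rng = np.random.default_rng(0)
x_ng = np.ones((2, 3))
y_ng = gaussian_noise(x_ng, 0.5, 1.0, p_apply=1.0, rng=rng)
```

Unlike the binomial and dropout transforms, additive noise can produce negative or non-integer values.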
- class cellarium.ml.transforms.Log1p(*args: Any, **kwargs: Any)
Bases: Module
Log1p transform gene counts.
\[y_{ng} = \log(1 + x_{ng})\]
- Parameters:
args (Any)
kwargs (Any)
- forward(x_ng: Tensor) → dict[str, Tensor]
Note
When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary will be overwritten with the log1p-transformed values.
- Parameters:
x_ng (Tensor) – Gene counts.
- Returns:
A dictionary with the following keys:
x_ng: The log1p transformed gene counts.
- Return type:
dict[str, Tensor]
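In NumPy the same formula is a single call to np.log1p. That function stays accurate for small inputs, where computing log(1 + x) directly would lose precision. Zero counts map to zero, so sparse count matrices keep their zeros. The helper log1p_transform is a hypothetical stand-in that mirrors the documented return dictionary.

```python
import numpy as np


def log1p_transform(x_ng):
    """Hypothetical sketch of y_ng = log(1 + x_ng), returned under the x_ng key."""
    # np.log1p evaluates log(1 + x) accurately even when x is tiny
    return {"x_ng": np.log1p(x_ng)}


out = log1p_transform(np.array([[0.0, np.e - 1.0]]))
```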
- class cellarium.ml.transforms.NormalizeTotal(target_count: int = 10000, eps: float = 1e-06)
Bases: Module
Normalize total gene counts per cell to a target count.
\[ \begin{aligned}\mathrm{total\_mrna\_umis}_n &= \sum_{g=1}^G x_{ng}\\ y_{ng} &= \frac{\mathrm{target\_count} \times x_{ng}}{\mathrm{total\_mrna\_umis}_n + \mathrm{eps}}\end{aligned} \]
- Parameters:
target_count (int) – Target total gene expression count per cell.
eps (float) – A value added to the denominator for numerical stability.
- forward(x_ng: Tensor, total_mrna_umis_n: Tensor | None = None) → dict[str, Tensor]
Note
When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary will be overwritten with the normalized values.
- Parameters:
x_ng (Tensor) – Gene counts.
total_mrna_umis_n (Tensor | None) – Total mRNA UMI counts per cell. If None, it is computed from x_ng.
- Returns:
A dictionary with the following keys:
x_ng: The gene counts normalized to target count.
- Return type:
dict[str, Tensor]
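The normalization formula, including the fallback when total_mrna_umis_n is None, can be sketched in NumPy. normalize_total is a hypothetical helper, not the library's forward.

```python
import numpy as np


def normalize_total(x_ng, total_mrna_umis_n=None, target_count=10_000, eps=1e-6):
    """Hypothetical sketch of y_ng = target_count * x_ng / (total_n + eps)."""
    if total_mrna_umis_n is None:
        # Per-cell total computed from the counts that are present
        total_mrna_umis_n = x_ng.sum(axis=1)
    return {"x_ng": target_count * x_ng / (total_mrna_umis_n[:, None] + eps)}


x_ng = np.array([[1.0, 3.0], [2.0, 2.0]])
computed = normalize_total(x_ng, target_count=100, eps=0.0)
explicit = normalize_total(
    x_ng, total_mrna_umis_n=np.array([8.0, 8.0]), target_count=100, eps=0.0
)
```

When the totals are computed from x_ng, each row sums to target_count (up to eps). When larger totals are passed explicitly, for example totals from the full gene set before filtering, the rows sum to less than target_count.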
- class cellarium.ml.transforms.ZScore(mean_g: Tensor, std_g: Tensor, var_names_g: ndarray, eps: float = 1e-06)
Bases: FilterCompatibilityMixin, Module
Z-score gene counts using per-gene means and standard deviations.
\[y_{ng} = \frac{x_{ng} - \mathrm{mean}_g}{\mathrm{std}_g + \mathrm{eps}}\]
- Parameters:
mean_g (Tensor) – Means for each gene.
std_g (Tensor) – Standard deviations for each gene.
var_names_g (ndarray) – The variable names schema for the input data validation.
eps (float) – A value added to the denominator for numerical stability.
- forward(x_ng: Tensor, var_names_g: ndarray) → dict[str, Tensor]
Note
When used with CellariumModule or CellariumPipeline, the x_ng key in the input dictionary will be overwritten with the z-scored values.
- Parameters:
x_ng (Tensor) – Gene counts.
var_names_g (ndarray) – The list of the variable names in the input data. Must be a subset of (or equal to) the var_names_g schema the transform was initialized with, in any order.
- Returns:
A dictionary with the following keys:
x_ng: The z-scored gene counts.
- Return type:
dict[str, Tensor]
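Because the input genes may be any subset of the schema in any order, the per-gene mean and standard deviation presumably have to be looked up by name before applying the formula. The NumPy sketch below follows that reading. z_score and its name-based alignment are assumptions, not the library's implementation. A gene missing from the schema raises KeyError here.

```python
import numpy as np


def z_score(x_ng, var_names_g, mean_g, std_g, schema_var_names_g, eps=1e-6):
    """Hypothetical sketch of y_ng = (x_ng - mean_g) / (std_g + eps), aligned by name."""
    # Position of every schema gene in mean_g / std_g
    position = {name: i for i, name in enumerate(schema_var_names_g)}
    # Statistics for each input column, in the input's own order
    idx = np.array([position[name] for name in var_names_g], dtype=np.int64)
    return {"x_ng": (x_ng - mean_g[idx]) / (std_g[idx] + eps)}


schema = np.array(["A", "B"])
mean_g = np.array([1.0, 10.0])
std_g = np.array([2.0, 5.0])
x_ng = np.array([[20.0, 5.0]])  # columns are genes "B" and "A"
out = z_score(x_ng, np.array(["B", "A"]), mean_g, std_g, schema, eps=0.0)
```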