# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
"""
Command line interface for Cellarium ML.
"""
import copy
import sys
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from functools import cache
from operator import attrgetter
from typing import Any, cast
import numpy as np
import pandas as pd
import torch
import yaml
from jsonargparse import Namespace, class_from_function
from jsonargparse._loaders_dumpers import DefaultLoader
from jsonargparse._util import import_object
from lightning.pytorch.cli import ArgsType, LightningArgumentParser, LightningCLI
from torch.utils._pytree import tree_map
from cellarium.ml import CellariumAnnDataDataModule, CellariumModule, CellariumPipeline
from cellarium.ml.utilities.data import AnnDataField, collate_fn
cached_loaders: dict[Callable[[str], Any] | str, Callable[[str], Any]] = {}
def _resolve_loader(loader_fn: Callable[[str], Any] | str) -> Callable[[str], Any]:
if isinstance(loader_fn, str):
loader_fn = import_object(loader_fn)
if loader_fn not in cached_loaders:
cached_loaders[loader_fn] = cache(loader_fn) # type: ignore[arg-type]
return cached_loaders[loader_fn]
def _resolve_value(
obj: Any,
attr: str | None = None,
key: Any = None,
convert_fn: Callable[[Any], Any] | str | None = None,
) -> Any:
if attr is not None:
obj = attrgetter(attr)(obj)
if key is not None:
obj = obj[key]
if isinstance(convert_fn, str):
convert_fn = import_object(convert_fn)
if convert_fn is not None:
obj = convert_fn(obj) # type: ignore[operator]
return obj
[docs]
@dataclass
class FileLoader:
    """
    A YAML constructor that loads a file and hands back (part of) its content.

    Calling ``FileLoader(...)`` does not build a ``FileLoader`` instance:
    ``__new__`` returns the loaded, resolved and optionally converted value.

    Example:

    .. code-block:: yaml

        model:
          cpu_transforms:
            - class_path: cellarium.ml.transforms.Filter
              init_args:
                filter_list:
                  !FileLoader
                  file_path: gs://dsp-cellarium-cas-public/test-data/filter_list.csv
                  loader_fn: pandas.read_csv
                  attr: index
                  convert_fn: numpy.ndarray.tolist

    Args:
        file_path:
            The file path to load the object from.
        loader_fn:
            A function (or a dotted import path to one) that loads the object from the file path.
        attr:
            An attribute to get from the loaded object. If ``None`` the loaded object is used as is.
        key:
            An item key used to subscript the object after ``attr`` has been applied.
            If ``None`` no subscripting is done.
        convert_fn:
            A function (or a dotted import path to one) that converts the resolved object.
            If ``None`` no conversion is done.
    """

    file_path: str
    loader_fn: Callable[[str], Any] | str
    attr: str | None = None
    key: str | None = None
    convert_fn: Callable[[Any], Any] | str | None = None

    def __new__(cls, file_path, loader_fn, attr=None, key=None, convert_fn=None):
        # Load through the shared memoised loader, then drill into the result.
        load = _resolve_loader(loader_fn)
        loaded = load(file_path)
        return _resolve_value(loaded, attr=attr, key=key, convert_fn=convert_fn)
[docs]
@dataclass
class FileMultiLoader:
    """
    A YAML constructor that loads a file once and extracts several fields from it.

    Applied to an ``init_args`` mapping, it returns a dict that is unpacked as keyword
    arguments to the constructor, so all fields can be given in a single block instead
    of repeating the file path and loader for each argument.

    Example:

    .. code-block:: yaml

        model:
          transforms:
            - class_path: cellarium.ml.transforms.ZScore
              init_args:
                !FileMultiLoader
                file_path: /tmp/test_examples/onepass/onepass.csv
                loader_fn: pandas.read_csv
                fields:
                  mean_g:
                    attr: mean_g.values
                    convert_fn: torch.FloatTensor
                  std_g:
                    attr: std_g.values
                    convert_fn: torch.FloatTensor
                  var_names_g:
                    attr: var_names_g
                    convert_fn: pandas.Series.to_numpy

    Args:
        file_path:
            The file path to load the object from.
        loader_fn:
            A function (or a dotted import path to one) that loads the object from the file path.
        fields:
            A mapping from output key names to field specs. Each spec may contain:
            ``attr`` (dotted attribute path), ``key`` (item key), and ``convert_fn``
            (importable callable string or ``None``).
    """

    file_path: str
    loader_fn: Callable[[str], Any] | str
    fields: dict[str, dict]

    def __new__(cls, file_path, loader_fn, fields) -> dict:  # type: ignore[misc]
        # The file is loaded once; every field spec is resolved against that one object.
        loaded = _resolve_loader(loader_fn)(file_path)
        resolved = {}
        for name, spec in fields.items():
            resolved[name] = _resolve_value(loaded, **spec)
        return resolved
[docs]
@dataclass
class CheckpointLoader(FileLoader):
    """
    A YAML constructor for loading a :class:`~cellarium.ml.core.CellariumModule` checkpoint
    and accessing its attributes.

    This is :class:`FileLoader` with the loader fixed to
    ``CellariumModule.load_from_checkpoint``.

    Example:

    .. code-block:: yaml

        model:
          transforms:
            - class_path: cellarium.ml.transforms.DivideByScale
              init_args:
                scale_g:
                  !CheckpointLoader
                  file_path: gs://dsp-cellarium-cas-public/test-data/tdigest.ckpt
                  attr: model.median_g
                  convert_fn: null

    Args:
        file_path:
            The file path to load the checkpoint from.
        attr:
            An attribute to get from the loaded object. If ``None`` the loaded object is used as is.
        key:
            An item key used to subscript the object after ``attr`` has been applied.
            If ``None`` no subscripting is done.
        convert_fn:
            A function to convert the resolved object. If ``None`` no conversion is done.
    """

    file_path: str
    attr: str | None = None
    key: str | None = None
    convert_fn: Callable[[Any], Any] | str | None = None

    def __new__(cls, file_path, attr=None, key=None, convert_fn=None):
        # Delegate to FileLoader with the checkpoint loader plugged in.
        return super().__new__(
            cls,
            file_path=file_path,
            loader_fn=CellariumModule.load_from_checkpoint,
            attr=attr,
            key=key,
            convert_fn=convert_fn,
        )
[docs]
def file_loader_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode) -> FileLoader:
    """
    Construct an object from a file.

    The mapping under the ``!FileLoader`` tag supplies the keyword arguments
    for :class:`FileLoader`.
    """
    kwargs = loader.construct_mapping(node)
    return FileLoader(**kwargs)  # type: ignore[arg-type]
[docs]
def file_multi_loader_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode) -> dict:
    """
    Construct a dict of objects from a file.

    The mapping under the ``!FileMultiLoader`` tag is built with ``deep=True`` so the
    nested ``fields`` specs are fully constructed before being passed to
    :class:`FileMultiLoader`.
    """
    kwargs = loader.construct_mapping(node, deep=True)
    return FileMultiLoader(**kwargs)  # type: ignore[arg-type, return-value]
[docs]
def checkpoint_loader_constructor(loader: yaml.SafeLoader, node: yaml.nodes.MappingNode) -> CheckpointLoader:
    """
    Construct an object from a checkpoint.

    The mapping under the ``!CheckpointLoader`` tag supplies the keyword arguments
    for :class:`CheckpointLoader`.
    """
    kwargs = loader.construct_mapping(node)
    return CheckpointLoader(**kwargs)  # type: ignore[arg-type]
# Register the custom YAML tags on jsonargparse's ``DefaultLoader`` at import time.
# NOTE(review): presumably this is the loader jsonargparse uses when parsing config
# files, which is what makes the tags usable in CLI configs — confirm.
loader = DefaultLoader
loader.add_constructor("!FileLoader", file_loader_constructor)
loader.add_constructor("!FileMultiLoader", file_multi_loader_constructor)
loader.add_constructor("!CheckpointLoader", checkpoint_loader_constructor)
# Model CLI entry points keyed by function name; ``main`` looks the model name up here.
REGISTERED_MODELS: dict[str, Callable[[ArgsType], None]] = {}


def register_model(model: Callable[[ArgsType], None]):
    """
    Decorator that records a model CLI function in ``REGISTERED_MODELS``.

    The function is stored under its ``__name__`` and returned unchanged.
    """
    REGISTERED_MODELS.update({model.__name__: model})
    return model
# Class-like wrapper built by jsonargparse's ``class_from_function`` around
# ``CellariumModule.load_from_checkpoint``, declared as producing a ``CellariumModule``.
# NOTE(review): presumably exposed so configs can reference it via ``class_path`` — confirm.
CellariumModuleLoadFromCheckpoint = class_from_function(CellariumModule.load_from_checkpoint, CellariumModule)
[docs]
@dataclass
class LinkArguments:
    """
    Arguments for linking the value of a target argument to the values of one or more source arguments.

    Instances are consumed by :func:`lightning_cli_factory`, which forwards the four fields,
    in declaration order, to ``parser.link_arguments``.

    Args:
        source:
            Key(s) from which the target value is derived.
        target:
            Key to where the value is set.
        compute_fn:
            Function to compute target value from source. Defaults to ``None``.
        apply_on:
            At what point to set target value, ``"parse"`` or ``"instantiate"``.
            Defaults to ``"instantiate"``.
    """

    source: str | tuple[str, ...]
    target: str
    compute_fn: Callable | None = None
    apply_on: str = "instantiate"
[docs]
def compute_n_obs(data: CellariumAnnDataDataModule) -> int:
    """
    Return the number of observations in the data.

    Args:
        data: A :class:`CellariumAnnDataDataModule` instance.

    Returns:
        The value of ``data.dadc.n_obs``.
    """
    collection = data.dadc
    return collection.n_obs
[docs]
def compute_n_vars(data: CellariumAnnDataDataModule) -> int:
    """
    Return the number of variables in the data.

    Args:
        data: A :class:`CellariumAnnDataDataModule` instance.

    Returns:
        The value of ``data.dadc.n_vars``.
    """
    collection = data.dadc
    return collection.n_vars
[docs]
def compute_y_categories(data: CellariumAnnDataDataModule) -> np.ndarray:
    """
    Compute the categories in the target variable.

    The ``"y_categories"`` entry of ``data.batch_keys`` is applied to the first
    element of ``data.dadc``. E.g. if the target variable is ``obs["cell_type"]``
    then this function returns the categories in ``obs["cell_type"]``::

        >>> np.asarray(data.dadc[0].obs["cell_type"].cat.categories)

    Args:
        data: A :class:`CellariumAnnDataDataModule` instance.

    Returns:
        The categories in the target variable.
    """
    first_adata = data.dadc[0]
    y_field = data.batch_keys["y_categories"]
    # narrows the type for static checkers; the field must be callable on an AnnData
    assert isinstance(y_field, AnnDataField)
    categories = y_field(first_adata)
    return categories
[docs]
def compute_var_names_g(
    cpu_transforms: list[torch.nn.Module] | None,
    transforms: list[torch.nn.Module] | None,
    data: CellariumAnnDataDataModule,
) -> np.ndarray:
    """
    Compute variable names by pushing one example through the transforms.

    A single-item batch is built from the first element of ``data.dadc``, collated,
    and run through copies of ``cpu_transforms`` followed by ``transforms``. The
    ``"var_names_g"`` entry of the pipeline output is returned.

    Args:
        cpu_transforms:
            A list of CPU transforms applied by the dataloader.
        transforms:
            A list of transforms.
        data:
            A :class:`CellariumAnnDataDataModule` instance.

    Returns:
        The variable names.

    Raises:
        ValueError:
            If ``x_ng`` is a scipy sparse matrix and no ``Filter`` transform is present,
            if ``x_ng`` is a sparse CSR tensor and no ``Densify`` transform is present,
            or if the pipeline fails with an error that looks sparse-related.
    """
    import scipy.sparse

    from cellarium.ml.transforms.densify import Densify

    first_adata = data.dadc[0]
    example = tree_map(lambda field: field(first_adata), data.batch_keys)

    # The transforms may live on the "meta" device (Lightning CLI instantiates classes
    # under torch.device("meta"); see CellariumModule.configure_model). Work on deep
    # copies and give any meta tensors uninitialised CPU storage, so the pipeline can
    # run on real inputs without touching the originals or calling reset_parameters.
    steps: list[torch.nn.Module] = []
    for source_transform in [*(cpu_transforms or []), *(transforms or [])]:
        clone = copy.deepcopy(source_transform)
        tensors = [*clone.parameters(), *clone.buffers()]
        if any(tensor.device.type == "meta" for tensor in tensors):
            clone.to_empty(device="cpu")
        steps.append(clone)

    # Pre-flight checks on x_ng: fail with an actionable message before the pipeline
    # runs. They also cover direct Python callers that bypass the CLI validator in
    # before_instantiate_classes.
    collated = collate_fn([example])
    x_ng = collated.get("x_ng")
    if x_ng is not None:
        # scipy sparse input is only valid with a Filter transform in the pipeline
        # (Filter converts scipy sparse -> torch.sparse_csr_tensor).
        missing_filter = scipy.sparse.issparse(x_ng) and not any(
            type(step).__name__ == "Filter" for step in steps
        )
        if missing_filter:
            raise ValueError(
                "Configuration Error: x_ng is a scipy sparse matrix in compute_var_names_g. "
                "The `keep_sparse` convert_fn requires a Filter in `model.cpu_transforms` to convert "
                "scipy sparse to torch.sparse_csr_tensor before this point.\n\n"
                "Either:\n"
                " (1) Add Filter to `model.cpu_transforms` with your gene filter_list, OR\n"
                " (2) Switch to `convert_fn: cellarium.ml.utilities.data.to_torch_sparse_csr` and add "
                "`Densify` as the first entry in `model.transforms`.\n"
            )

        # a sparse CSR tensor needs a Densify transform somewhere in the pipeline
        missing_densify = (
            isinstance(x_ng, torch.Tensor)
            and x_ng.is_sparse_csr
            and not any(isinstance(step, Densify) for step in steps)
        )
        if missing_densify:
            raise ValueError(
                "Configuration Error: x_ng is a torch.sparse_csr_tensor but no `Densify` transform "
                "was found in the combined cpu_transforms + transforms list.\n"
                "`Densify` must be the first entry in `model.transforms` to convert "
                "torch.sparse_csr_tensor to dense before any dense-only transforms run.\n\n"
                "Add to model.transforms (as first item):\n"
                " - cellarium.ml.transforms.Densify\n"
            )

    pipeline = CellariumPipeline(steps)
    try:
        output = pipeline(collated)
    except (RuntimeError, ValueError, TypeError) as exc:
        error_msg = str(exc)
        sparse_related = "stride" in error_msg or "sparse" in error_msg.lower() or "csr_matrix" in error_msg
        if not sparse_related:
            # unrelated failure: let the original exception propagate untouched
            raise
        raise ValueError(
            f"compute_var_names_g pipeline failed: {error_msg}\n\n"
            "This typically occurs when using the sparse data path (keep_sparse or to_torch_sparse_csr) "
            "and Filter is incorrectly placed in model.transforms (GPU transforms) instead of "
            "model.cpu_transforms (CPU transforms). When using sparse data, Filter MUST be in "
            "model.cpu_transforms to filter and convert to torch.sparse_csr_tensor before GPU transfer. "
            "Can also occur if Densify is not included as the first transform.\n"
            "\n\nCorrect configuration:\n"
            " model:\n"
            " cpu_transforms:\n"
            " - class_path: cellarium.ml.transforms.Filter\n"
            " init_args:\n"
            " filter_list: [...]\n"
            " transforms:\n"
            " - cellarium.ml.transforms.Densify\n"
            " - cellarium.ml.transforms.NormalizeTotal\n"
            " - ...\n"
            " data:\n"
            " batch_keys:\n"
            " x_ng:\n"
            " attr: X\n"
            " convert_fn: cellarium.ml.utilities.data.keep_sparse"
        ) from exc
    return output["var_names_g"]
[docs]
def compute_batch_index_n_categories(data: CellariumAnnDataDataModule) -> int:
    """
    Compute the number of categories in batch_index_n.

    .. note::
        If batch_index_n is comprised of multiple keys, the number of categories is computed
        as the product of the number of categories in each key.

    Args:
        data: A :class:`CellariumAnnDataDataModule` instance.

    Returns:
        The number of categories in batch_index_n, or ``1`` when ``data.batch_keys``
        has no ``"batch_index_n"`` entry.
    """
    if "batch_index_n" not in data.batch_keys:
        # for hvg selection when no batch key is given, treat all cells as a single batch
        return 1

    batch_field = data.batch_keys["batch_index_n"]
    assert isinstance(batch_field, AnnDataField)
    container = getattr(data.dadc[0], batch_field.attr)
    values = container[batch_field.key]

    if not isinstance(values, pd.DataFrame):
        return len(values.cat.categories)
    # several columns: multiply the per-column category counts
    per_column_counts = values.apply(lambda column: len(column.cat.categories))
    return int(per_column_counts.product())
def _get_transform_name(transform_spec: Any) -> str | None:
"""
Extract a canonical transform class name from a transform specification.
Handles both string paths, short names, class_path dicts, and instantiated modules.
Returns None if unable to extract.
"""
if isinstance(transform_spec, str):
# Already a string path like "cellarium.ml.transforms.Filter" or short name like "Filter"
return transform_spec.split(".")[-1] if "." in transform_spec else transform_spec
elif isinstance(transform_spec, dict):
# YAML class_path format: {"class_path": "cellarium.ml.transforms.Filter", "init_args": {...}}
if "class_path" in transform_spec:
class_path = transform_spec["class_path"]
return class_path.split(".")[-1]
elif isinstance(transform_spec, torch.nn.Module):
# Instantiated module
return transform_spec.__class__.__name__
return None
def _validate_sparse_config(config: dict[Any, Any] | Namespace) -> list[str]:
"""
Validate sparse data path configuration.
Raises ValueError for fatal misconfigurations, returns list of advisory warnings.
Checks for common misconfigurations like Filter in wrong transform list.
"""
warnings_list: list[str] = []
try:
# Convert Namespace to dict if needed
if isinstance(config, Namespace):
config = vars(config)
# Try to access model and data config; skip if not present (e.g., during --help)
if "model" not in config or "data" not in config:
return warnings_list
model_config = config.get("model", {})
data_config = config.get("data", {})
# Extract batch_keys and x_ng convert_fn
batch_keys = data_config.get("batch_keys", {})
x_ng_config = batch_keys.get("x_ng", {})
convert_fn = x_ng_config.get("convert_fn", "")
if isinstance(convert_fn, str):
convert_fn_name = convert_fn.split(".")[-1]
else:
convert_fn_name = ""
# Only validate if using sparse convert_fn
if convert_fn_name not in ("keep_sparse", "to_torch_sparse_csr"):
return warnings_list
# Extract cpu_transforms and transforms
cpu_transforms = model_config.get("cpu_transforms", [])
if cpu_transforms is None:
cpu_transforms = []
transforms = model_config.get("transforms", [])
if transforms is None:
transforms = []
# Extract transform names from both lists
cpu_transform_names = [_get_transform_name(t) for t in cpu_transforms]
cpu_transform_names = [n for n in cpu_transform_names if n] # filter None
transform_names = [_get_transform_name(t) for t in transforms]
transform_names = [n for n in transform_names if n] # filter None
# error: keep_sparse without Filter in cpu_transforms
if convert_fn_name == "keep_sparse" and "Filter" not in cpu_transform_names:
raise ValueError(
"Configuration Error: Using `keep_sparse` but Filter not found in `model.cpu_transforms`.\n"
"The `keep_sparse` convert_fn requires a Filter cpu_transform to filter sparse data and convert to "
"torch.sparse_csr_tensor BEFORE GPU transfer.\n\n"
"Either:\n"
" (1) Add Filter to `model.cpu_transforms` with your gene filter_list, OR\n"
" (2) Switch to `convert_fn: cellarium.ml.utilities.data.to_torch_sparse_csr` and ensure "
"`Densify` is the first entry in `model.transforms`.\n"
)
# error: Sparse path without Densify
if convert_fn_name in ("keep_sparse", "to_torch_sparse_csr") and "Densify" not in transform_names:
raise ValueError(
f"Configuration Error: Using sparse data path ({convert_fn_name}) but `Densify` not found in "
f"`model.transforms`.\n"
f"`Densify` must be the first entry in `model.transforms` to convert torch.sparse_csr_tensor to dense "
f"on GPU before any dense-only transforms run.\n\n"
f"Add to model.transforms (as first item):\n"
f" - cellarium.ml.transforms.Densify\n"
)
# warning: keep_sparse with Filter in wrong transform list
if convert_fn_name == "keep_sparse":
if "Filter" in transform_names and "Filter" not in cpu_transform_names:
warnings_list.append(
"Configuration Issue: Using `keep_sparse` with Filter in `model.transforms` (GPU transforms) "
"instead of `model.cpu_transforms` (CPU transforms). \n"
"When using `keep_sparse`, Filter MUST be in `model.cpu_transforms` to filter and convert to "
"torch.sparse_csr_tensor BEFORE GPU transfer. Placing Filter in GPU transforms defeats the "
"purpose of sparse transfer.\n\n"
"Correct configuration:\n"
" model:\n"
" cpu_transforms:\n"
" - class_path: cellarium.ml.transforms.Filter\n"
" init_args:\n"
" filter_list: [...]\n"
" transforms:\n"
" - cellarium.ml.transforms.Densify\n"
" - ...\n"
)
# warning: Densify not first in transforms
if convert_fn_name in ("keep_sparse", "to_torch_sparse_csr"):
if transform_names and transform_names[0] != "Densify" and "Densify" in transform_names:
densify_idx = transform_names.index("Densify")
warnings_list.append(
f"Configuration Issue: `Densify` found at position {densify_idx} in `model.transforms` but should "
f"be first (position 0). Dense-only transforms must run AFTER densification.\n"
f"Move `Densify` to the first entry in `model.transforms`.\n"
)
except ValueError:
# Re-raise ValueError (fatal configuration errors)
raise
except Exception:
# Silently skip validation if config structure is unexpected
pass
return warnings_list
[docs]
def lightning_cli_factory(
    model_class_path: str,
    link_arguments: list[LinkArguments] | None = None,
    trainer_defaults: dict[str, Any] | None = None,
) -> type[LightningCLI]:
    """
    Build a :class:`LightningCLI` subclass with a preset model and custom argument linking.

    Example::

        cli = lightning_cli_factory(
            "cellarium.ml.models.IncrementalPCA",
            link_arguments=[
                LinkArguments(
                    ("model.cpu_transforms", "model.transforms", "data"),
                    "model.model.init_args.var_names_g",
                    compute_var_names_g,
                )
            ],
            trainer_defaults={
                "max_epochs": 1,  # one pass
                "strategy": {
                    "class_path": "lightning.pytorch.strategies.DDPStrategy",
                    "dict_kwargs": {"broadcast_buffers": False},
                },
            },
        )

    Args:
        model_class_path:
            A string representation of the model class path (e.g., ``"cellarium.ml.models.IncrementalPCA"``).
        link_arguments:
            A list of :class:`LinkArguments` that specify how to derive the value of a target
            argument from the values of one or more source arguments. If ``None`` then no
            arguments are linked.
        trainer_defaults:
            Default values for the trainer.

    Returns:
        A :class:`LightningCLI` class with the given model and argument linking.
    """

    class NewLightningCLI(LightningCLI):
        def __init__(self, args: ArgsType = None) -> None:
            # model and data module classes are fixed; only the args vary per call
            super().__init__(
                CellariumModule,
                CellariumAnnDataDataModule,
                trainer_defaults=trainer_defaults,
                args=args,
            )

        def _add_instantiators(self) -> None:
            # Intentionally a no-op: disables the breaking dependency injection support
            # change introduced in PyTorch Lightning 2.3
            # https://github.com/Lightning-AI/pytorch-lightning/pull/18105
            return None

        def before_instantiate_classes(self):
            # warn when predict would keep predictions in memory
            if self.subcommand == "predict" and self.config["predict"]["return_predictions"]:
                warnings.warn(
                    "The `return_predictions` argument should be set to 'false' when running predict to avoid OOM. "
                    "This can be set at indent level 0 in the config file. Example:\n"
                    "model: ...\ndata: ...\ntrainer: ...\nreturn_predictions: false",
                    UserWarning,
                )

            # sparse data path validation works on the subcommand-scoped config
            active_config = self.config.get(cast(str, self.subcommand), {})
            for advisory in _validate_sparse_config(active_config):
                warnings.warn(advisory, UserWarning)

            return super().before_instantiate_classes()

        def instantiate_classes(self) -> None:
            # Instantiate on the meta device to skip the initialization of model parameters;
            # they are later initialized by the `CellariumModule.configure_model` method.
            meta_device = torch.device("meta")
            with meta_device:
                return super().instantiate_classes()

        def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
            for link in link_arguments or ():
                parser.link_arguments(link.source, link.target, link.compute_fn, link.apply_on)

            # this is helpful for generating a default config file with --print_config
            preset_defaults = {
                "model.model": model_class_path,
                "data.dadc": "cellarium.ml.data.DistributedAnnDataCollection",
            }
            parser.set_defaults(preset_defaults)

    return NewLightningCLI
[docs]
@register_model
def cellarium_gpt(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.CellariumGPT` model.

    No arguments are linked and no trainer defaults are preset for this model.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    cli_class = lightning_cli_factory("cellarium.ml.models.CellariumGPT")
    cli_class(args=args)
[docs]
@register_model
def hvg_seurat_v3(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.HVGSeuratV3` model.

    Computes highly variable genes using the Seurat v3 method over two Lightning
    epochs. Raw-count data (no log-normalisation) is expected.

    The number of batches (``n_batch``) is inferred automatically from the
    ``batch_index_n`` batch key in ``data.batch_keys``::

        batch_index_n:
          attr: obs
          key: <your_batch_obs_column>
          convert_fn: cellarium.ml.utilities.data.get_categories

    Example run::

        cellarium-ml hvg_seurat_v3 fit \
            --model.model.init_args.n_top_genes 2000 \
            --model.model.init_args.batch_key batch_index_n \
            --model.model.init_args.output_path hvg.csv \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/counts_{0..3}.h5ad" \
            --data.shard_size 10000 \
            --data.max_cache_size 2 \
            --data.batch_size 512 \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    var_names_link = LinkArguments(
        source=("model.cpu_transforms", "model.transforms", "data"),
        target="model.model.init_args.var_names_g",
        compute_fn=compute_var_names_g,
    )
    n_batch_link = LinkArguments(
        source="data",
        target="model.model.init_args.n_batch",
        compute_fn=compute_batch_index_n_categories,
    )
    two_epoch_ddp = {
        "max_epochs": 2,
        "strategy": {
            "class_path": "lightning.pytorch.strategies.DDPStrategy",
            "dict_kwargs": {"broadcast_buffers": False},
        },
    }
    cli_class = lightning_cli_factory(
        "cellarium.ml.models.HVGSeuratV3",
        link_arguments=[var_names_link, n_batch_link],
        trainer_defaults=two_epoch_ddp,
    )
    cli_class(args=args)
[docs]
@register_model
def incremental_pca(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.IncrementalPCA` model.

    This example shows how to fit feature count data to incremental PCA
    model [1, 2].

    Example run::

        cellarium-ml incremental_pca fit \
            --model.model.init_args.n_components 50 \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_size 100 \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1 \
            --trainer.default_root_dir runs/ipca

    **References:**

    1. `A Distributed and Incremental SVD Algorithm for Agglomerative Data Analysis on Large Networks (Iwen et al.)
       <https://users.math.msu.edu/users/iwenmark/Papers/distrib_inc_svd.pdf>`_.
    2. `Incremental Learning for Robust Visual Tracking (Ross et al.)
       <https://www.cs.toronto.edu/~dross/ivt/RossLimLinYang_ijcv.pdf>`_.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    var_names_link = LinkArguments(
        source=("model.cpu_transforms", "model.transforms", "data"),
        target="model.model.init_args.var_names_g",
        compute_fn=compute_var_names_g,
    )
    one_pass_ddp = {
        "max_epochs": 1,  # one pass
        "strategy": {
            "class_path": "lightning.pytorch.strategies.DDPStrategy",
            "dict_kwargs": {"broadcast_buffers": False},
        },
    }
    cli_class = lightning_cli_factory(
        "cellarium.ml.models.IncrementalPCA",
        link_arguments=[var_names_link],
        trainer_defaults=one_pass_ddp,
    )
    cli_class(args=args)
[docs]
@register_model
def logistic_regression(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.LogisticRegression` model.

    Example run::

        cellarium-ml logistic_regression fit \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_keys.x_ng.attr X \
            --data.batch_keys.x_ng.convert_fn cellarium.ml.utilities.data.densify \
            --data.batch_keys.var_names_g.attr var_names \
            --data.batch_keys.y_n.attr obs \
            --data.batch_keys.y_n.key cell_type \
            --data.batch_keys.y_n.convert_fn cellarium.ml.utilities.data.categories_to_codes \
            --data.batch_size 100 \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1 \
            --trainer.max_steps 1000

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    links = [
        LinkArguments(
            source=("model.cpu_transforms", "model.transforms", "data"),
            target="model.model.init_args.var_names_g",
            compute_fn=compute_var_names_g,
        ),
        LinkArguments(source="data", target="model.model.init_args.n_obs", compute_fn=compute_n_obs),
        LinkArguments(source="data", target="model.model.init_args.y_categories", compute_fn=compute_y_categories),
    ]
    cli_class = lightning_cli_factory("cellarium.ml.models.LogisticRegression", link_arguments=links)
    cli_class(args=args)
[docs]
@register_model
def onepass_mean_var_std(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.OnePassMeanVarStd` model.

    This example shows how to calculate mean, variance, and standard deviation of log normalized
    feature count data in one pass [1].

    Example run::

        cellarium-ml onepass_mean_var_std fit \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_size 100 \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1 \
            --trainer.default_root_dir runs/onepass

    **References:**

    1. `Algorithms for calculating variance
       <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance>`_.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    var_names_link = LinkArguments(
        source=("model.cpu_transforms", "model.transforms", "data"),
        target="model.model.init_args.var_names_g",
        compute_fn=compute_var_names_g,
    )
    one_pass_ddp = {
        "max_epochs": 1,  # one pass
        "strategy": {
            "class_path": "lightning.pytorch.strategies.DDPStrategy",
            "dict_kwargs": {"broadcast_buffers": False},
        },
    }
    cli_class = lightning_cli_factory(
        "cellarium.ml.models.OnePassMeanVarStd",
        link_arguments=[var_names_link],
        trainer_defaults=one_pass_ddp,
    )
    cli_class(args=args)
[docs]
@register_model
def probabilistic_pca(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.ProbabilisticPCA` model.

    This example shows how to fit feature count data to probabilistic PCA
    model [1].

    There are two flavors of probabilistic PCA model that are available:

    1. ``marginalized`` - latent variable ``z`` is marginalized out [1]. Marginalized
       model provides a closed-form solution for the marginal log-likelihood.
       Closed-form solution for the marginal log-likelihood has reduced
       variance compared to the ``linear_vae`` model.
    2. ``linear_vae`` - latent variable ``z`` has a diagonal Gaussian distribution [2].
       Training a linear VAE with variational inference recovers a uniquely identifiable
       global maximum corresponding to the principal component directions.
       The global maximum of the ELBO objective for the linear VAE is identical
       to the global maximum for the marginal log-likelihood of probabilistic PCA.

    Example run::

        cellarium-ml probabilistic_pca fit \
            --model.model.init_args.n_components 256 \
            --model.model.init_args.ppca_flavor marginalized \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_size 100 \
            --data.shuffle true \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1 \
            --trainer.max_steps 1000 \
            --trainer.default_root_dir runs/ppca

    **References:**

    1. `Probabilistic Principal Component Analysis (Tipping et al.)
       <https://www.robots.ox.ac.uk/~cvrg/hilary2006/ppca.pdf>`_.
    2. `Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.)
       <https://openreview.net/pdf?id=r1xaVLUYuE>`_.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    links = [
        LinkArguments(
            source=("model.cpu_transforms", "model.transforms", "data"),
            target="model.model.init_args.var_names_g",
            compute_fn=compute_var_names_g,
        ),
        LinkArguments(source="data", target="model.model.init_args.n_obs", compute_fn=compute_n_obs),
    ]
    cli_class = lightning_cli_factory("cellarium.ml.models.ProbabilisticPCA", link_arguments=links)
    cli_class(args=args)
[docs]
@register_model
def tdigest(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.TDigest` model.

    This example shows how to calculate non-zero median of normalized feature count
    data in one pass [1].

    Example run::

        cellarium-ml tdigest fit \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_size 100 \
            --data.num_workers 4 \
            --trainer.accelerator cpu \
            --trainer.devices 1 \
            --trainer.default_root_dir runs/tdigest

    **References:**

    1. `Computing Extremely Accurate Quantiles Using T-Digests (Dunning et al.)
       <https://github.com/tdunning/t-digest/blob/master/docs/t-digest-paper/histo.pdf>`_.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    var_names_link = LinkArguments(
        source=("model.cpu_transforms", "model.transforms", "data"),
        target="model.model.init_args.var_names_g",
        compute_fn=compute_var_names_g,
    )
    one_pass = {
        "max_epochs": 1,  # one pass
    }
    cli_class = lightning_cli_factory(
        "cellarium.ml.models.TDigest",
        link_arguments=[var_names_link],
        trainer_defaults=one_pass,
    )
    cli_class(args=args)
[docs]
@register_model
def contrastive_mlp(args: ArgsType = None) -> None:
    r"""
    CLI to run the :class:`cellarium.ml.models.ContrastiveMLP` model.

    This example shows how to perform contrastive learning with a default augmentation
    strategy for omics data.

    Example run::

        cellarium-ml contrastive_mlp fit \
            --model.model.init_args.hidden_size 4096 2048 1024 512 \
            --model.model.init_args.embed_dim 256 \
            --model.model.init_args.temperature 1.0 \
            --model.model.init_args.target_count 10000 \
            --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
            --data.shard_size 100 \
            --data.max_cache_size 2 \
            --data.batch_size 100 \
            --data.num_workers 4 \
            --trainer.accelerator gpu \
            --trainer.devices 1 \
            --trainer.default_root_dir runs/contrastive

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
    """
    # NOTE(review): the target is named ``n_obs`` but is fed by ``compute_n_vars``.
    # Presumably the model uses ``n_obs`` as its input feature count — confirm before changing.
    input_size_link = LinkArguments(
        source="data",
        target="model.model.init_args.n_obs",
        compute_fn=compute_n_vars,
    )
    twenty_epochs = {
        "max_epochs": 20,
    }
    cli_class = lightning_cli_factory(
        "cellarium.ml.models.ContrastiveMLP",
        link_arguments=[input_size_link],
        trainer_defaults=twenty_epochs,
    )
    cli_class(args=args)
[docs]
def main(args: ArgsType = None) -> None:
    """
    CLI that dispatches to the appropriate model cli based on the model name in ``args`` and runs it.

    The caller's ``args`` object is left untouched: the model name is removed from a
    copy before the remaining arguments are handed to the model CLI.

    Args:
        args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``.
            The model name is expected to be the first argument if ``args`` is a list
            or the ``model_name`` key if ``args`` is a dictionary or ``Namespace``.

    Raises:
        ValueError:
            If the model name is missing or is not one of the registered models.
        TypeError:
            If ``args`` is not a list, dict, ``Namespace`` or ``None``.
    """
    if isinstance(args, (dict, Namespace)):
        if "model_name" not in args:
            raise ValueError("'model_name' key must be specified in args")
        # shallow copy: only the top-level key is removed, so the caller's object stays intact
        args = copy.copy(args)
        model_name = args.pop("model_name")
    elif isinstance(args, list):
        if not args:
            raise ValueError("'model_name' must be specified as the first argument in args")
        # slicing builds a new list instead of popping from the caller's list
        model_name, args = args[0], args[1:]
    elif args is None:
        cli_args = sys.argv[1:]
        if not cli_args:
            raise ValueError("'model_name' must be specified after cellarium-ml")
        model_name, args = cli_args[0], cli_args[1:]
    else:
        # previously this fell through and failed with an UnboundLocalError
        raise TypeError(f"'args' must be a list, dict, Namespace, or None. Got '{type(args).__name__}'")

    if model_name not in REGISTERED_MODELS:
        raise ValueError(f"'model_name' must be one of {list(REGISTERED_MODELS.keys())}. Got '{model_name}'")
    model_cli = REGISTERED_MODELS[model_name]

    # silence known-noisy warnings for the duration of the model run only
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="Transforming to str index.")
        warnings.filterwarnings("ignore", message="LightningCLI's args parameter is intended to run from within Python")
        warnings.filterwarnings("ignore", message="Your `IterableDataset` has `__len__` defined.")
        model_cli(args)  # run the model
# Script entry point: with no explicit args, ``main`` reads the model name and the
# remaining arguments from ``sys.argv``.
if __name__ == "__main__":
    main()