CLI
Command line interface for Cellarium ML.
- class cellarium.ml.cli.FileLoader(file_path, loader_fn, attr, convert_fn)[source]
Bases:
object
A YAML constructor for loading a file and accessing its attributes.
Example:
    model:
      cpu_transforms:
        - class_path: cellarium.ml.transforms.Filter
          init_args:
            filter_list:
              !FileLoader
              file_path: gs://dsp-cellarium-cas-public/test-data/filter_list.csv
              loader_fn: pandas.read_csv
              attr: index
              convert_fn: numpy.ndarray.tolist
- Parameters:
file_path (str) – The file path to load the object from.
loader_fn (Callable[[str], Any] | str) – A function to load the object from the file path.
attr (str | None) – An attribute to get from the loaded object. If None, the loaded object is returned.
convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None, the loaded object is returned.
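The parameter descriptions above amount to a three-step pipeline: load, pick an attribute, convert. The following is a minimal sketch of that pipeline, not the FileLoader implementation itself. The names `resolve` and `load_from_file` are hypothetical, resolving dotted strings with `pydoc.locate` is an assumption, and `pathlib.Path` stands in for a real loader such as `pandas.read_csv` so that the sketch runs without third-party packages or network access.

```python
from operator import attrgetter
from pydoc import locate


def resolve(fn):
    """Turn a dotted-path string into the object it names; pass callables through."""
    if isinstance(fn, str):
        resolved = locate(fn)
        if resolved is None:
            raise ImportError(f"cannot resolve {fn!r}")
        return resolved
    return fn


def load_from_file(file_path, loader_fn, attr=None, convert_fn=None):
    """Hypothetical stand-in for the load -> attr -> convert steps."""
    obj = resolve(loader_fn)(file_path)

    # attr=None keeps the loaded object; dotted names walk nested attributes.
    if attr is not None:
        obj = attrgetter(attr)(obj)

    # convert_fn=None returns the object unchanged.
    if convert_fn is not None:
        obj = resolve(convert_fn)(obj)
    return obj


# pathlib.Path plays the role of the loader, "suffix" the attribute,
# and str.upper the conversion function.
suffix = load_from_file(
    "filter_list.csv", "pathlib.Path", attr="suffix", convert_fn="str.upper"
)
print(suffix)  # .CSV
```

In the YAML example above, the same three roles are filled by `pandas.read_csv`, `index`, and `numpy.ndarray.tolist`.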
- class cellarium.ml.cli.CheckpointLoader(file_path, attr, convert_fn)[source]
Bases:
FileLoader
A YAML constructor for loading a CellariumModule checkpoint and accessing its attributes.
Example:
    model:
      transforms:
        - class_path: cellarium.ml.transforms.DivideByScale
          init_args:
            scale_g:
              !CheckpointLoader
              file_path: gs://dsp-cellarium-cas-public/test-data/tdigest.ckpt
              attr: model.median_g
              convert_fn: null
- Parameters:
file_path (str) – The file path to load the object from.
attr (str | None) – An attribute to get from the loaded object. If None, the loaded object is returned.
convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None, the loaded object is returned.
- cellarium.ml.cli.file_loader_constructor(loader: SafeLoader, node: MappingNode) FileLoader[source]
Construct an object from a file.
- Parameters:
loader (SafeLoader)
node (MappingNode)
- Return type:
FileLoader
- cellarium.ml.cli.checkpoint_loader_constructor(loader: SafeLoader, node: MappingNode) CheckpointLoader[source]
Construct an object from a checkpoint.
- Parameters:
loader (SafeLoader)
node (MappingNode)
- Return type:
CheckpointLoader
- cellarium.ml.cli.CellariumModuleLoadFromCheckpoint
- class cellarium.ml.cli.LinkArguments(source: str | tuple[str, ...], target: str, compute_fn: Callable | None = None, apply_on: str = 'instantiate')[source]
Bases:
object
Arguments for linking the value of a target argument to the values of one or more source arguments.
- Parameters:
source (str | tuple[str, ...]) – Key(s) from which the target value is derived.
target (str) – Key to where the value is set.
compute_fn (Callable | None) – Function to compute target value from source.
apply_on (str) – At what point to set the target value: "parse" or "instantiate".
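To illustrate what linking means, here is a minimal sketch that derives a target key from one or more source keys in a nested dictionary. It is only an illustration: `LinkSketch`, `get_nested`, `set_nested`, and `apply_link` are hypothetical names, the real linking is presumably delegated to the CLI parser, and the sketch ignores the `apply_on` timing entirely.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class LinkSketch:
    """Hypothetical mirror of the four LinkArguments fields."""

    source: str | tuple[str, ...]
    target: str
    compute_fn: Callable | None = None
    apply_on: str = "instantiate"  # not used by this sketch


def get_nested(config: dict, key: str) -> Any:
    """Read a dotted key such as "data.batch_size" from nested dicts."""
    for part in key.split("."):
        config = config[part]
    return config


def set_nested(config: dict, key: str, value: Any) -> None:
    """Write a dotted key, creating intermediate dicts as needed."""
    *parents, leaf = key.split(".")
    for part in parents:
        config = config.setdefault(part, {})
    config[leaf] = value


def apply_link(config: dict, link: LinkSketch) -> None:
    """Derive the target value from the source value(s) and store it."""
    sources = (link.source,) if isinstance(link.source, str) else link.source
    values = [get_nested(config, key) for key in sources]
    if link.compute_fn is not None:
        value = link.compute_fn(*values)
    elif len(values) == 1:
        value = values[0]  # a single source is copied as is
    else:
        raise ValueError("compute_fn is required for multiple sources")
    set_nested(config, link.target, value)


config = {"data": {"batch_size": 100, "num_workers": 4}, "model": {}}
apply_link(config, LinkSketch("data.batch_size", "model.init_args.batch_size"))
apply_link(
    config,
    LinkSketch(
        ("data.batch_size", "data.num_workers"),
        "model.init_args.total",
        compute_fn=lambda batch_size, num_workers: batch_size * num_workers,
    ),
)
print(config["model"]["init_args"])  # {'batch_size': 100, 'total': 400}
```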
- cellarium.ml.cli.compute_n_obs(data: CellariumAnnDataDataModule) int[source]
Compute the number of observations in the data.
- Parameters:
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The number of observations in the data.
- Return type:
int
- cellarium.ml.cli.compute_y_categories(data: CellariumAnnDataDataModule) ndarray[source]
Compute the categories in the target variable.
E.g. if the target variable is obs["cell_type"] then this function returns the categories in obs["cell_type"]:
    >>> np.asarray(data.dadc[0].obs["cell_type"].cat.categories)
- Parameters:
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The categories in the target variable.
- Return type:
ndarray
- cellarium.ml.cli.compute_var_names_g(cpu_transforms: list[Module] | None, transforms: list[Module] | None, data: CellariumAnnDataDataModule) ndarray[source]
Compute variable names from the data by applying the transforms.
- Parameters:
cpu_transforms (list[Module] | None) – A list of CPU transforms applied by the dataloader.
transforms (list[Module] | None) – A list of transforms.
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The variable names.
- Return type:
ndarray
- cellarium.ml.cli.lightning_cli_factory(model_class_path: str, link_arguments: list[LinkArguments] | None = None, trainer_defaults: dict[str, Any] | None = None) type[LightningCLI][source]
Factory function for creating a LightningCLI with a preset model and custom argument linking.
Example:
    cli = lightning_cli_factory(
        "cellarium.ml.models.IncrementalPCA",
        link_arguments=[
            LinkArguments(
                ("model.cpu_transforms", "model.transforms", "data"),
                "model.model.init_args.var_names_g",
                compute_var_names_g,
            )
        ],
        trainer_defaults={
            "max_epochs": 1,  # one pass
            "strategy": {
                "class_path": "lightning.pytorch.strategies.DDPStrategy",
                "dict_kwargs": {"broadcast_buffers": False},
            },
        },
    )
- Parameters:
model_class_path (str) – A string representation of the model class path (e.g., "cellarium.ml.models.IncrementalPCA").
link_arguments (list[LinkArguments] | None) – A list of LinkArguments that specify how to derive the value of a target argument from the values of one or more source arguments. If None then no arguments are linked.
trainer_defaults (dict[str, Any] | None) – Default values for the trainer.
- Returns:
A LightningCLI class with the given model and argument linking.
- Return type:
type[LightningCLI]
- cellarium.ml.cli.cellarium_gpt(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.CellariumGPT model.
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.geneformer(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.Geneformer model.
This example shows how to fit feature count data to the Geneformer model [1].
Example run:
    cellarium-ml geneformer fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 5 \
        --data.num_workers 1 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/geneformer \
        --trainer.max_steps 10
References:
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.incremental_pca(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.IncrementalPCA model.
This example shows how to fit feature count data to the incremental PCA model [1, 2].
Example run:
    cellarium-ml incremental_pca fit \
        --model.model.init_args.n_components 50 \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/ipca
References:
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.logistic_regression(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.LogisticRegression model.
Example run:
    cellarium-ml logistic_regression fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_keys.x_ng.attr X \
        --data.batch_keys.x_ng.convert_fn cellarium.ml.utilities.data.densify \
        --data.batch_keys.var_names_g.attr var_names \
        --data.batch_keys.y_n.attr obs \
        --data.batch_keys.y_n.key cell_type \
        --data.batch_keys.y_n.convert_fn cellarium.ml.utilities.data.categories_to_codes \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.max_steps 1000
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.onepass_mean_var_std(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.OnePassMeanVarStd model.
This example shows how to calculate the mean, variance, and standard deviation of log-normalized feature count data in one pass [1].
Example run:
    cellarium-ml onepass_mean_var_std fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/onepass
References:
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
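The one-pass idea mentioned above can be illustrated with Welford's online update, a standard technique for streaming mean and variance. This is a sketch on plain Python floats and is not taken from the OnePassMeanVarStd implementation, which may use a different algorithm. The function name is hypothetical, the log normalization step is omitted, and dividing by n (population variance) is an assumption.

```python
import math


def onepass_mean_var_std_sketch(batches):
    """Return (mean, variance, std) of a stream of batches, reading each value once."""
    n = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for batch in batches:
        for x in batch:
            n += 1
            delta = x - mean
            mean += delta / n
            m2 += delta * (x - mean)
    if n == 0:
        raise ValueError("no data")
    var = m2 / n  # population variance (assumption)
    return mean, var, math.sqrt(var)


# Two batches standing in for shards streamed by a dataloader.
mean, var, std = onepass_mean_var_std_sketch([[1.0, 2.0], [3.0, 4.0]])
print(mean, var)  # 2.5 1.25
```

Only three running quantities are kept in memory, so the data never needs to be held in full.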
- cellarium.ml.cli.probabilistic_pca(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.ProbabilisticPCA model.
This example shows how to fit feature count data to the probabilistic PCA model [1].
Two flavors of the probabilistic PCA model are available:
marginalized - latent variable z is marginalized out [1]. The marginalized model provides a closed-form solution for the marginal log-likelihood, which has reduced variance compared to the linear_vae model.
linear_vae - latent variable z has a diagonal Gaussian distribution [2]. Training a linear VAE with variational inference recovers a uniquely identifiable global maximum corresponding to the principal component directions. The global maximum of the ELBO objective for the linear VAE is identical to the global maximum for the marginal log-likelihood of probabilistic PCA.
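For reference, the closed-form marginal log-likelihood mentioned for the marginalized flavor can be written out using the standard probabilistic PCA formulation from [1]. The symbols W (loading matrix), \mu (mean), \sigma^2 (isotropic noise variance), k (number of components), g (number of features), n (number of observations), and S (sample covariance) are notation assumed here, and need not match the parameter names used by the model.

```latex
\begin{aligned}
z &\sim \mathcal{N}(0,\ I_k), \\
x \mid z &\sim \mathcal{N}(W z + \mu,\ \sigma^2 I_g), \\
x &\sim \mathcal{N}(\mu,\ C), \qquad C = W W^\top + \sigma^2 I_g, \\
\log p(x_1, \dots, x_n) &= -\frac{n}{2} \left[ g \log(2\pi) + \log \lvert C \rvert + \operatorname{tr}\!\left( C^{-1} S \right) \right],
\qquad S = \frac{1}{n} \sum_{i=1}^{n} (x_i - \mu)(x_i - \mu)^\top .
\end{aligned}
```

Because z is integrated out analytically in the third line, no sampling of z is needed to evaluate the objective.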
Example run:
    cellarium-ml probabilistic_pca fit \
        --model.model.init_args.n_components 256 \
        --model.model.init_args.ppca_flavor marginalized \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.shuffle true \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.max_steps 1000 \
        --trainer.default_root_dir runs/ppca
References:
[1] Probabilistic Principal Component Analysis (Tipping et al.).
[2] Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.tdigest(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI to run the cellarium.ml.models.TDigest model.
This example shows how to calculate the non-zero median of normalized feature count data in one pass [1].
Example run:
    cellarium-ml tdigest fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator cpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/tdigest
References:
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
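As a small illustration of the quantity named above, the sketch below computes a non-zero median exactly on an in-memory list. It is only meant to show what "non-zero median" means; the t-digest used by the model is an approximate streaming summary and does not work this way. The function name `nonzero_median` is hypothetical.

```python
from statistics import median


def nonzero_median(values):
    """Median of the entries that are not zero."""
    nonzero = [v for v in values if v != 0]
    if not nonzero:
        raise ValueError("no non-zero values")
    return median(nonzero)


# Count data is typically sparse, so zeros dominate the plain median.
counts = [0.0, 0.0, 1.5, 0.0, 2.5, 4.0]
print(median(counts))          # 0.75
print(nonzero_median(counts))  # 2.5
```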
- cellarium.ml.cli.main(args: list[str] | dict[str, Any] | Namespace | None = None) None[source]
CLI that dispatches to the appropriate model CLI based on the model name in args and runs it.
- Parameters:
args (list[str] | dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv. The model name is expected to be the first argument if args is a list, or the model_name key if args is a dictionary or Namespace.
- Return type:
None
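The dispatch rule described above can be sketched as follows. This is an illustration under stated assumptions rather than the actual `main` implementation: `split_model_name`, `dispatch`, and the `registry` mapping are hypothetical names, and `argparse.Namespace` stands in for whichever Namespace type the CLI really uses.

```python
from __future__ import annotations

import sys
from argparse import Namespace
from typing import Any, Callable


def split_model_name(args: list[str] | dict[str, Any] | Namespace | None = None):
    """Separate the model name from the remaining arguments."""
    if args is None:
        args = sys.argv[1:]  # fall back to the command line
    if isinstance(args, list):
        return args[0], args[1:]  # model name is the first argument
    if isinstance(args, dict):
        rest = dict(args)
        return rest.pop("model_name"), rest
    if isinstance(args, Namespace):
        rest = dict(vars(args))
        return rest.pop("model_name"), Namespace(**rest)
    raise TypeError(f"unsupported args type: {type(args).__name__}")


def dispatch(args, registry: dict[str, Callable]):
    """Look up the model CLI by name and run it with the remaining arguments."""
    model_name, rest = split_model_name(args)
    if model_name not in registry:
        raise ValueError(f"unknown model name: {model_name!r}")
    return registry[model_name](rest)


# A toy registry; each entry echoes what it received.
registry = {"tdigest": lambda rest: ("tdigest", rest)}
print(dispatch(["tdigest", "fit", "--trainer.devices", "1"], registry))
print(dispatch({"model_name": "tdigest", "subcommand": "fit"}, registry))
```

The input is copied before the model name is removed, so the caller's list, dictionary, or Namespace is left unmodified.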
- class cellarium.ml.cli.LightningModule__load_from_checkpoint_class(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any)
Bases:
CellariumModule, ClassFromFunctionBase
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".
Any arguments specified through **kwargs will override args stored in "hyper_parameters".
- Parameters:
checkpoint_path (str | Path | IO) – Path to checkpoint. This can also be a URL or a file-like object.
map_location (device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[device | str | int, device | str | int] | None) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file (str | Path | None) – Optional path to a .yaml or .csv file with hierarchical structure as in this example:
    drop_prob: 0.2
    dataloader:
        batch_size: 32
You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.
If your model’s hparams argument is Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict.
strict (bool | None) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict. Defaults to True unless LightningModule.strict_loading is set, in which case it defaults to the value of LightningModule.strict_loading.
**kwargs (Any) – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Returns:
LightningModule instance with loaded weights and hyperparameters (if available).
- Return type:
Self
Note
load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance, or a TypeError will be raised.
Note
To ensure all layers can be loaded from the checkpoint, this function will call configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, note that load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).
Example:
    # load weights without mapping ...
    model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

    # or load weights mapping all weights from GPU 1 to GPU 0 ...
    map_location = {'cuda:1': 'cuda:0'}
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        map_location=map_location
    )

    # or load weights and hyperparameters from separate files.
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams_file.yaml'
    )

    # override some of the params with new values
    model = MyLightningModule.load_from_checkpoint(
        PATH,
        num_layers=128,
        pretrained_ckpt_path=NEW_PATH,
    )

    # predict with the loaded model
    model.eval()
    model.freeze()
    y_hat = model(x)
- __new__(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) Self
- wrapped_function(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) Self