CLI
Command line interface for Cellarium ML.
- class cellarium.ml.cli.FileLoader(file_path, loader_fn, attr, convert_fn)[source]
Bases: object
A YAML constructor for loading a file and accessing its attributes.
Example:

    model:
      cpu_transforms:
        - class_path: cellarium.ml.transforms.Filter
          init_args:
            filter_list: !FileLoader
              file_path: gs://dsp-cellarium-cas-public/test-data/filter_list.csv
              loader_fn: pandas.read_csv
              attr: index
              convert_fn: numpy.ndarray.tolist
- Parameters:
file_path (str) – The file path to load the object from.
loader_fn (Callable[[str], Any] | str) – A function to load the object from the file path.
attr (str | None) – An attribute to get from the loaded object. If None, the loaded object is returned.
convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None, the loaded object is returned.
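The tag amounts to a load → getattr → convert pipeline. A minimal Python sketch of these semantics (not the actual implementation; reading gs:// paths with pandas assumes gcsfs is installed):

    import pandas as pd

    # loader_fn: load the object from file_path
    df = pd.read_csv("gs://dsp-cellarium-cas-public/test-data/filter_list.csv")
    # attr: take the named attribute from the loaded object
    index = df.index
    # convert_fn: convert the result to a plain Python list
    # (the equivalent of numpy.ndarray.tolist)
    filter_list = index.to_numpy().tolist()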
- class cellarium.ml.cli.CheckpointLoader(file_path, attr, convert_fn)[source]
Bases: FileLoader
A YAML constructor for loading a CellariumModule checkpoint and accessing its attributes.
Example:

    model:
      transforms:
        - class_path: cellarium.ml.transforms.DivideByScale
          init_args:
            scale_g: !CheckpointLoader
              file_path: gs://dsp-cellarium-cas-public/test-data/tdigest.ckpt
              attr: model.median_g
              convert_fn: null
- Parameters:
file_path (str) – The file path to load the object from.
attr (str | None) – An attribute to get from the loaded object. If None, the loaded object is returned.
convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None, the loaded object is returned.
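A sketch of the equivalent manual steps, assuming CellariumModule is importable from cellarium.ml and exposes the standard Lightning load_from_checkpoint (the attr string is resolved as a dotted attribute path):

    from cellarium.ml import CellariumModule

    # load the checkpointed module, then walk the dotted attr path
    module = CellariumModule.load_from_checkpoint(
        "gs://dsp-cellarium-cas-public/test-data/tdigest.ckpt"
    )
    scale_g = module.model.median_g  # attr "model.median_g"; convert_fn is null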
- cellarium.ml.cli.file_loader_constructor(loader: SafeLoader, node: MappingNode) → FileLoader [source]
Construct an object from a file.
- Parameters:
loader (SafeLoader)
node (MappingNode)
- Return type:
FileLoader
- cellarium.ml.cli.checkpoint_loader_constructor(loader: SafeLoader, node: MappingNode) → CheckpointLoader [source]
Construct an object from a checkpoint.
- Parameters:
loader (SafeLoader)
node (MappingNode)
- Return type:
CheckpointLoader
- cellarium.ml.cli.CellariumModuleLoadFromCheckpoint
Alias of LightningModule__load_from_checkpoint_class (documented below).
- class cellarium.ml.cli.LinkArguments(source: str | tuple[str, ...], target: str, compute_fn: Callable | None = None, apply_on: str = 'instantiate')[source]
Bases: object
Arguments for linking the value of a target argument to the values of one or more source arguments.
- Parameters:
source (str | tuple[str, ...]) – Key(s) from which the target value is derived.
target (str) – Key to which the value is set.
compute_fn (Callable | None) – Function to compute the target value from the source.
apply_on (str) – At what point to set the target value, "parse" or "instantiate".
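For example, to derive a model argument from the data module at instantiation time (the target key below is hypothetical; compute_n_obs is documented later in this module):

    # hypothetical link: fill the model's n_obs from the data module
    link = LinkArguments(
        source="data",
        target="model.model.init_args.n_obs",  # hypothetical target key
        compute_fn=compute_n_obs,
        apply_on="instantiate",
    )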
- cellarium.ml.cli.compute_n_obs(data: CellariumAnnDataDataModule) → int [source]
Compute the number of observations in the data.
- Parameters:
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The number of observations in the data.
- Return type:
int
- cellarium.ml.cli.compute_y_categories(data: CellariumAnnDataDataModule) → ndarray [source]
Compute the categories in the target variable.
E.g., if the target variable is obs["cell_type"], then this function returns the categories in obs["cell_type"]:

    >>> np.asarray(data.dadc[0].obs["cell_type"].cat.categories)
- Parameters:
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The categories in the target variable.
- Return type:
ndarray
- cellarium.ml.cli.compute_var_names_g(cpu_transforms: list[Module] | None, transforms: list[Module] | None, data: CellariumAnnDataDataModule) → ndarray [source]
Compute variable names from the data by applying the transforms.
- Parameters:
cpu_transforms (list[Module] | None) – A list of CPU transforms applied by the dataloader.
transforms (list[Module] | None) – A list of transforms.
data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.
- Returns:
The variable names.
- Return type:
ndarray
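In practice this function is wired up through LinkArguments so that the model's var_names_g is derived automatically, as in the lightning_cli_factory example below:

    LinkArguments(
        ("model.cpu_transforms", "model.transforms", "data"),
        "model.model.init_args.var_names_g",
        compute_var_names_g,
    )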
- cellarium.ml.cli.lightning_cli_factory(model_class_path: str, link_arguments: list[LinkArguments] | None = None, trainer_defaults: dict[str, Any] | None = None) → type[LightningCLI] [source]
Factory function for creating a LightningCLI with a preset model and custom argument linking.
Example:

    cli = lightning_cli_factory(
        "cellarium.ml.models.IncrementalPCA",
        link_arguments=[
            LinkArguments(
                ("model.cpu_transforms", "model.transforms", "data"),
                "model.model.init_args.var_names_g",
                compute_var_names_g,
            )
        ],
        trainer_defaults={
            "max_epochs": 1,  # one pass
            "strategy": {
                "class_path": "lightning.pytorch.strategies.DDPStrategy",
                "dict_kwargs": {"broadcast_buffers": False},
            },
        },
    )
- Parameters:
model_class_path (str) – A string representation of the model class path (e.g., "cellarium.ml.models.IncrementalPCA").
link_arguments (list[LinkArguments] | None) – A list of LinkArguments that specify how to derive the value of a target argument from the values of one or more source arguments. If None, then no arguments are linked.
trainer_defaults (dict[str, Any] | None) – Default values for the trainer.
- Returns:
A LightningCLI class with the given model and argument linking.
- Return type:
type[LightningCLI]
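Instantiating the returned class behaves like any LightningCLI: it parses the arguments and runs the requested subcommand. A sketch, assuming the subclass forwards LightningCLI's args parameter:

    # build a CLI class once, then run a subcommand with an explicit argument list
    cli_cls = lightning_cli_factory("cellarium.ml.models.IncrementalPCA")
    cli_cls(args=["fit", "--data.filenames", "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad"])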
- cellarium.ml.cli.geneformer(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.Geneformer model.
This example shows how to fit feature count data to the Geneformer model [1].
Example run:
    cellarium-ml geneformer fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 5 \
        --data.num_workers 1 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/geneformer \
        --trainer.max_steps 10
References:
1. Transfer learning enables predictions in network biology (Theodoris et al.).
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
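The same run can also be launched programmatically by passing the argument list directly (a sketch; flags as in the example run above):

    from cellarium.ml.cli import geneformer

    geneformer([
        "fit",
        "--data.filenames", "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad",
        "--data.batch_size", "5",
        "--trainer.max_steps", "10",
    ])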
- cellarium.ml.cli.incremental_pca(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.IncrementalPCA model.
This example shows how to fit feature count data to the incremental PCA model [1, 2].
Example run:
    cellarium-ml incremental_pca fit \
        --model.model.init_args.n_components 50 \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/ipca
References:
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.logistic_regression(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.LogisticRegression model.
Example run:
    cellarium-ml logistic_regression fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_keys.x_ng.attr X \
        --data.batch_keys.x_ng.convert_fn cellarium.ml.utilities.data.densify \
        --data.batch_keys.var_names_g.attr var_names \
        --data.batch_keys.y_n.attr obs \
        --data.batch_keys.y_n.key cell_type \
        --data.batch_keys.y_n.convert_fn cellarium.ml.utilities.data.categories_to_codes \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.max_steps 1000
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.onepass_mean_var_std(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.OnePassMeanVarStd model.
This example shows how to calculate the mean, variance, and standard deviation of log-normalized feature count data in one pass [1].
Example run:
    cellarium-ml onepass_mean_var_std fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/onepass
References:
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.probabilistic_pca(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.ProbabilisticPCA model.
This example shows how to fit feature count data to the probabilistic PCA model [1].
There are two available flavors of the probabilistic PCA model:
- marginalized – the latent variable z is marginalized out [1]. The marginalized model provides a closed-form solution for the marginal log-likelihood, which has reduced variance compared to the linear_vae model.
- linear_vae – the latent variable z has a diagonal Gaussian distribution [2]. Training a linear VAE with variational inference recovers a uniquely identifiable global maximum corresponding to the principal component directions. The global maximum of the ELBO objective for the linear VAE is identical to the global maximum of the marginal log-likelihood of probabilistic PCA.
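For reference, a sketch of the probabilistic PCA generative model from [1] (W is the loadings matrix, \mu the mean, \sigma^2 the isotropic noise variance), showing what the marginalized flavor integrates out:

    z_n \sim \mathcal{N}(0, I_k)
    x_n \mid z_n \sim \mathcal{N}(W z_n + \mu,\ \sigma^2 I_g)
    % marginalizing z_n yields the closed-form marginal likelihood:
    x_n \sim \mathcal{N}(\mu,\ W W^\top + \sigma^2 I_g)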
Example run:
    cellarium-ml probabilistic_pca fit \
        --model.model.init_args.n_components 256 \
        --model.model.init_args.ppca_flavor marginalized \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.shuffle true \
        --data.num_workers 4 \
        --trainer.accelerator gpu \
        --trainer.devices 1 \
        --trainer.max_steps 1000 \
        --trainer.default_root_dir runs/ppca
References:
1. Probabilistic Principal Component Analysis (Tipping et al.).
2. Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.tdigest(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI to run the cellarium.ml.models.TDigest model.
This example shows how to calculate the non-zero median of normalized feature count data in one pass [1].
Example run:
    cellarium-ml tdigest fit \
        --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
        --data.shard_size 100 \
        --data.max_cache_size 2 \
        --data.batch_size 100 \
        --data.num_workers 4 \
        --trainer.accelerator cpu \
        --trainer.devices 1 \
        --trainer.default_root_dir runs/tdigest
References:
1. Computing extremely accurate quantiles using t-digests (Dunning et al.).
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv.
- Return type:
None
- cellarium.ml.cli.main(args: List[str] | Dict[str, Any] | Namespace | None = None) → None [source]
CLI that dispatches to the appropriate model CLI based on the model name in args and runs it.
- Parameters:
args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None, the arguments are taken from sys.argv. The model name is expected to be the first argument if args is a list, or the model_name key if args is a dictionary or Namespace.
- Return type:
None
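A sketch of programmatic dispatch using the list form, where the model name comes first and the remaining items follow the example runs above:

    from cellarium.ml.cli import main

    main([
        "onepass_mean_var_std", "fit",
        "--data.filenames", "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad",
        "--data.batch_size", "100",
    ])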
- class cellarium.ml.cli.LightningModule__load_from_checkpoint_class(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any)
Bases: CellariumModule, ClassFromFunctionBase
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".
Any arguments specified through **kwargs will override args stored in "hyper_parameters".
- Parameters:
checkpoint_path (str | Path | IO) – Path to checkpoint. This can also be a URL, or a file-like object.
map_location (device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file (str | Path | None) – Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32

You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you'd like to use. These will be converted into a dict and passed into your LightningModule for use. If your model's hparams argument is Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as a dict.
strict (bool | None) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict. Defaults to True unless LightningModule.strict_loading is set, in which case it defaults to the value of LightningModule.strict_loading.
**kwargs (Any) – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Returns:
LightningModule instance with loaded weights and hyperparameters (if available).
- Return type:
Self
Note
load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance, or a TypeError will be raised.
Note
To ensure all layers can be loaded from the checkpoint, this function will call configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, note that load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).
Example:

    # load weights without mapping ...
    model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

    # or load weights mapping all weights from GPU 1 to GPU 0 ...
    map_location = {'cuda:1': 'cuda:0'}
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        map_location=map_location
    )

    # or load weights and hyperparameters from separate files.
    model = MyLightningModule.load_from_checkpoint(
        'path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams_file.yaml'
    )

    # override some of the params with new values
    model = MyLightningModule.load_from_checkpoint(
        PATH,
        num_layers=128,
        pretrained_ckpt_path=NEW_PATH,
    )

    # predict
    pretrained_model.eval()
    pretrained_model.freeze()
    y_hat = pretrained_model(x)
- __new__(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) → Self
Identical to the class-level documentation above; see there for the parameter descriptions, notes, and examples.
- wrapped_function(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) → Self
Identical to the class-level documentation above; see there for the parameter descriptions, notes, and examples.