CLI

Command line interface for Cellarium ML.

class cellarium.ml.cli.FileLoader(file_path, loader_fn, attr, convert_fn)

Bases: object

A YAML constructor for loading a file and accessing its attributes.

Example:

model:
  cpu_transforms:
    - class_path: cellarium.ml.transforms.Filter
      init_args:
        filter_list:
          !FileLoader
          file_path: gs://dsp-cellarium-cas-public/test-data/filter_list.csv
          loader_fn: pandas.read_csv
          attr: index
          convert_fn: numpy.ndarray.tolist
Parameters:
  • file_path (str) – The file path to load the object from.

  • loader_fn (Callable[[str], Any] | str) – A function to load the object from the file path.

  • attr (str | None) – An attribute to get from the loaded object. If None the loaded object is returned.

  • convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None the loaded object is returned.
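The load, attribute lookup, and conversion steps described by these parameters can be sketched in plain Python. This is an illustration only, not the library's implementation: load_attr, read_json_namespace, and the sample file are hypothetical, and resolving string paths such as pandas.read_csv into callables is left out.

```python
from __future__ import annotations

import json
import os
import tempfile
from operator import attrgetter
from types import SimpleNamespace
from typing import Any, Callable


def load_attr(
    file_path: str,
    loader_fn: Callable[[str], Any],
    attr: str | None = None,
    convert_fn: Callable[[Any], Any] | None = None,
) -> Any:
    """Hypothetical helper: load an object, optionally pick an attribute, optionally convert it."""
    obj = loader_fn(file_path)
    if attr is not None:
        # attrgetter also accepts dotted paths such as "model.median_g".
        obj = attrgetter(attr)(obj)
    if convert_fn is not None:
        obj = convert_fn(obj)
    return obj


def read_json_namespace(path: str) -> SimpleNamespace:
    """Hypothetical loader: expose the top-level keys of a JSON file as attributes."""
    with open(path) as f:
        return SimpleNamespace(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "filter_list.json")
    with open(path, "w") as f:
        json.dump({"genes": ["ENSG2", "ENSG1"]}, f)

    # attr selects the "genes" attribute; convert_fn sorts it.
    genes = load_attr(path, read_json_namespace, attr="genes", convert_fn=sorted)
    # With attr=None and convert_fn=None the loaded object is returned unchanged.
    whole = load_attr(path, read_json_namespace)

print(genes)
```

Passing attr=None and convert_fn=None returns the loaded object as is, which matches the behavior stated in the parameter list above.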

class cellarium.ml.cli.CheckpointLoader(file_path, attr, convert_fn)

Bases: FileLoader

A YAML constructor for loading a CellariumModule checkpoint and accessing its attributes.

Example:

model:
  transforms:
    - class_path: cellarium.ml.transforms.DivideByScale
      init_args:
        scale_g:
          !CheckpointLoader
          file_path: gs://dsp-cellarium-cas-public/test-data/tdigest.ckpt
          attr: model.median_g
          convert_fn: null
Parameters:
  • file_path (str) – The file path to load the object from.

  • attr (str | None) – An attribute to get from the loaded object. If None the loaded object is returned.

  • convert_fn (Callable[[Any], Any] | str | None) – A function to convert the loaded object. If None the loaded object is returned.

cellarium.ml.cli.file_loader_constructor(loader: SafeLoader, node: MappingNode) → FileLoader

Construct an object from a file.

Parameters:
  • loader (SafeLoader)

  • node (MappingNode)

Return type:

FileLoader

cellarium.ml.cli.checkpoint_loader_constructor(loader: SafeLoader, node: MappingNode) → CheckpointLoader

Construct an object from a checkpoint.

Parameters:
  • loader (SafeLoader)

  • node (MappingNode)

Return type:

CheckpointLoader

cellarium.ml.cli.CellariumModuleLoadFromCheckpoint

alias of LightningModule__load_from_checkpoint_class

class cellarium.ml.cli.LinkArguments(source: str | tuple[str, ...], target: str, compute_fn: Callable | None = None, apply_on: str = 'instantiate')

Bases: object

Arguments for linking the value of a target argument to the values of one or more source arguments.

Parameters:
  • source (str | tuple[str, ...]) – Key(s) from which the target value is derived.

  • target (str) – Key to where the value is set.

  • compute_fn (Callable | None) – Function to compute target value from source.

  • apply_on (str) – At what point to set target value, "parse" or "instantiate".
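The idea of deriving a target value from one or more source values can be illustrated on a plain nested dictionary. This is a sketch under stated assumptions: Link is a hypothetical stand-in for LinkArguments (apply_on is omitted), the helpers get_path, set_path, and apply_link and the config keys are invented for the example, and the real linking is performed by the CLI machinery rather than by code like this.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class Link:
    """Hypothetical stand-in for LinkArguments (apply_on is omitted)."""

    source: str | tuple[str, ...]
    target: str
    compute_fn: Callable | None = None


def get_path(config: dict, dotted: str) -> Any:
    """Read a value from a nested dict using a dotted key."""
    for key in dotted.split("."):
        config = config[key]
    return config


def set_path(config: dict, dotted: str, value: Any) -> None:
    """Write a value into a nested dict, creating intermediate dicts as needed."""
    *parents, last = dotted.split(".")
    for key in parents:
        config = config.setdefault(key, {})
    config[last] = value


def apply_link(config: dict, link: Link) -> None:
    """Set the target key from the source key(s), optionally through compute_fn."""
    sources = (link.source,) if isinstance(link.source, str) else link.source
    values = [get_path(config, s) for s in sources]
    if link.compute_fn is None:
        if len(values) != 1:
            raise ValueError("compute_fn is required when there are several sources")
        value = values[0]
    else:
        # Source values are passed positionally, in the order they are listed.
        value = link.compute_fn(*values)
    set_path(config, link.target, value)


config = {"data": {"n_obs": 1000, "batch_size": 100}, "model": {}}
apply_link(
    config,
    Link(("data.n_obs", "data.batch_size"), "model.init_args.n_batches", lambda n, b: n // b),
)
apply_link(config, Link("data.n_obs", "model.init_args.n_obs"))
print(config["model"])
```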

cellarium.ml.cli.compute_n_obs(data: CellariumAnnDataDataModule) → int

Compute the number of observations in the data.

Parameters:

data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.

Returns:

The number of observations in the data.

Return type:

int

cellarium.ml.cli.compute_y_categories(data: CellariumAnnDataDataModule) → ndarray

Compute the categories in the target variable.

E.g. if the target variable is obs["cell_type"] then this function returns the categories in obs["cell_type"]:

>>> np.asarray(data.dadc[0].obs["cell_type"].cat.categories)
Parameters:

data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.

Returns:

The categories in the target variable.

Return type:

ndarray

cellarium.ml.cli.compute_var_names_g(cpu_transforms: list[Module] | None, transforms: list[Module] | None, data: CellariumAnnDataDataModule) → ndarray

Compute variable names from the data by applying the transforms.

Parameters:
  • cpu_transforms (list[Module] | None) – A list of CPU transforms applied by the dataloader.

  • transforms (list[Module] | None) – A list of transforms.

  • data (CellariumAnnDataDataModule) – A CellariumAnnDataDataModule instance.

Returns:

The variable names.

Return type:

ndarray

cellarium.ml.cli.lightning_cli_factory(model_class_path: str, link_arguments: list[LinkArguments] | None = None, trainer_defaults: dict[str, Any] | None = None) → type[LightningCLI]

Factory function for creating a LightningCLI with a preset model and custom argument linking.

Example:

cli = lightning_cli_factory(
    "cellarium.ml.models.IncrementalPCA",
    link_arguments=[
        LinkArguments(
            ("model.cpu_transforms", "model.transforms", "data"),
            "model.model.init_args.var_names_g",
            compute_var_names_g,
        )
    ],
    trainer_defaults={
        "max_epochs": 1,  # one pass
        "strategy": {
            "class_path": "lightning.pytorch.strategies.DDPStrategy",
            "dict_kwargs": {"broadcast_buffers": False},
        },
    },
)
Parameters:
  • model_class_path (str) – A string representation of the model class path (e.g., "cellarium.ml.models.IncrementalPCA").

  • link_arguments (list[LinkArguments] | None) – A list of LinkArguments that specify how to derive the value of a target argument from the values of one or more source arguments. If None then no arguments are linked.

  • trainer_defaults (dict[str, Any] | None) – Default values for the trainer.

Returns:

A LightningCLI class with the given model and argument linking.

Return type:

type[LightningCLI]

cellarium.ml.cli.geneformer(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.Geneformer model.

This example shows how to fit feature count data to the Geneformer model [1].

Example run:

cellarium-ml geneformer fit \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_size 5 \
    --data.num_workers 1 \
    --trainer.accelerator gpu \
    --trainer.devices 1 \
    --trainer.default_root_dir runs/geneformer \
    --trainer.max_steps 10

References:

  1. Transfer learning enables predictions in network biology (Theodoris et al.).

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.incremental_pca(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.IncrementalPCA model.

This example shows how to fit feature count data to the incremental PCA model [1, 2].

Example run:

cellarium-ml incremental_pca fit \
    --model.model.init_args.n_components 50 \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_size 100 \
    --data.num_workers 4 \
    --trainer.accelerator gpu \
    --trainer.devices 1 \
    --trainer.default_root_dir runs/ipca

References:

  1. A Distributed and Incremental SVD Algorithm for Agglomerative Data Analysis on Large Networks (Iwen et al.).

  2. Incremental Learning for Robust Visual Tracking (Ross et al.).

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.logistic_regression(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.LogisticRegression model.

Example run:

cellarium-ml logistic_regression fit \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_keys.x_ng.attr X \
    --data.batch_keys.x_ng.convert_fn cellarium.ml.utilities.data.densify \
    --data.batch_keys.var_names_g.attr var_names \
    --data.batch_keys.y_n.attr obs \
    --data.batch_keys.y_n.key cell_type \
    --data.batch_keys.y_n.convert_fn cellarium.ml.utilities.data.categories_to_codes \
    --data.batch_size 100 \
    --data.num_workers 4 \
    --trainer.accelerator gpu \
    --trainer.devices 1 \
    --trainer.max_steps 1000
Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.onepass_mean_var_std(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.OnePassMeanVarStd model.

This example shows how to calculate mean, variance, and standard deviation of log normalized feature count data in one pass [1].
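As a rough illustration of what "one pass" means here, the sketch below merges per-batch statistics into running totals using the pairwise update described under "Algorithms for calculating variance" [1]. It is an assumption-laden example, not the OnePassMeanVarStd implementation: running_mean_var_std is a hypothetical helper, it works on a single feature, it expects values that are already transformed, and it reports the population variance.

```python
from __future__ import annotations

import math
from typing import Iterable, Sequence


def running_mean_var_std(batches: Iterable[Sequence[float]]) -> tuple[float, float, float]:
    """Hypothetical helper: mean, population variance, and std from a single pass over batches."""
    n = 0        # number of values seen so far
    mean = 0.0   # running mean
    m2 = 0.0     # running sum of squared deviations from the mean
    for batch in batches:
        nb = len(batch)
        if nb == 0:
            continue
        mean_b = sum(batch) / nb
        m2_b = sum((x - mean_b) ** 2 for x in batch)
        # Merge the batch statistics into the running statistics.
        delta = mean_b - mean
        total = n + nb
        mean += delta * nb / total
        m2 += m2_b + delta * delta * n * nb / total
        n = total
    if n == 0:
        raise ValueError("no data")
    var = m2 / n
    return mean, var, math.sqrt(var)


mean, var, std = running_mean_var_std([[1.0, 2.0], [3.0, 4.0, 5.0]])
print(mean, var, std)
```

Each batch is visited once and only three running numbers are kept, so memory use does not grow with the size of the dataset.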

Example run:

cellarium-ml onepass_mean_var_std fit \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_size 100 \
    --data.num_workers 4 \
    --trainer.accelerator gpu \
    --trainer.devices 1 \
    --trainer.default_root_dir runs/onepass

References:

  1. Algorithms for calculating variance.

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.probabilistic_pca(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.ProbabilisticPCA model.

This example shows how to fit feature count data to the probabilistic PCA model [1].

Two flavors of the probabilistic PCA model are available:

  1. marginalized - the latent variable z is marginalized out [1]. The marginalized model provides a closed-form solution for the marginal log-likelihood, which has reduced variance compared to the linear_vae model.

  2. linear_vae - latent variable z has a diagonal Gaussian distribution [2]. Training a linear VAE with variational inference recovers a uniquely identifiable global maximum corresponding to the principal component directions. The global maximum of the ELBO objective for the linear VAE is identical to the global maximum for the marginal log-likelihood of probabilistic PCA.
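For reference, the generative model behind the marginalized flavor can be written out as below. The notation is assumed here and follows Tipping et al. [1] rather than the Cellarium ML source: W is the loading matrix, μ the mean, σ² the isotropic noise variance, k the number of components, and g the number of features.

```latex
\begin{aligned}
z &\sim \mathcal{N}(0,\; I_k), \\
x \mid z &\sim \mathcal{N}(W z + \mu,\; \sigma^2 I_g), \\
x &\sim \mathcal{N}(\mu,\; C), \qquad C = W W^\top + \sigma^2 I_g, \\
\log p(x) &= -\tfrac{1}{2}\left[\, g \log(2\pi) + \log\lvert C \rvert + (x - \mu)^\top C^{-1} (x - \mu) \,\right].
\end{aligned}
```

The last line is the closed-form marginal log-likelihood referred to in item 1; the linear_vae flavor optimizes an ELBO on the same quantity instead.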

Example run:

cellarium-ml probabilistic_pca fit \
    --model.model.init_args.n_components 256 \
    --model.model.init_args.ppca_flavor marginalized \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_size 100 \
    --data.shuffle true \
    --data.num_workers 4 \
    --trainer.accelerator gpu \
    --trainer.devices 1 \
    --trainer.max_steps 1000 \
    --trainer.default_root_dir runs/ppca

References:

  1. Probabilistic Principal Component Analysis (Tipping et al.).

  2. Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.).

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.tdigest(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI to run the cellarium.ml.models.TDigest model.

This example shows how to calculate non-zero median of normalized feature count data in one pass [1].
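To make the term "non-zero median" concrete, the sketch below computes it exactly and in memory for each feature column of a small count matrix. This only illustrates the quantity being estimated: nonzero_median and the toy matrix are hypothetical, and the TDigest model approximates the median in one pass over the data rather than collecting all values as done here.

```python
import math
from statistics import median
from typing import Sequence


def nonzero_median(values: Sequence[float]) -> float:
    """Hypothetical helper: median of the non-zero entries, NaN if every entry is zero."""
    nonzero = [v for v in values if v != 0]
    return float(median(nonzero)) if nonzero else math.nan


# Rows are cells, columns are features (toy data).
counts = [
    [0, 2, 0],
    [4, 0, 0],
    [6, 8, 0],
]
# zip(*counts) iterates over columns, giving one median per feature.
medians = [nonzero_median(column) for column in zip(*counts)]
print(medians)
```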

Example run:

cellarium-ml tdigest fit \
    --data.filenames "gs://dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad" \
    --data.shard_size 100 \
    --data.max_cache_size 2 \
    --data.batch_size 100 \
    --data.num_workers 4 \
    --trainer.accelerator cpu \
    --trainer.devices 1 \
    --trainer.default_root_dir runs/tdigest

References:

  1. Computing Extremely Accurate Quantiles Using T-Digests (Dunning et al.).

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv.

Return type:

None

cellarium.ml.cli.main(args: List[str] | Dict[str, Any] | Namespace | None = None) → None

CLI that dispatches to the appropriate model CLI based on the model name in args and runs it.

Parameters:

args (List[str] | Dict[str, Any] | Namespace | None) – Arguments to parse. If None the arguments are taken from sys.argv. The model name is expected to be the first argument if args is a list or the model_name key if args is a dictionary or Namespace.

Return type:

None
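The dispatch rule stated above (first list element, or the model_name key of a dictionary or Namespace) can be sketched as follows. The names split_model_name, dispatch, and the registry argument are hypothetical and chosen for illustration; this is not the body of main.

```python
from __future__ import annotations

import sys
from argparse import Namespace
from typing import Any, Callable


def split_model_name(args: list[str] | dict[str, Any] | Namespace) -> tuple[str, Any]:
    """Hypothetical helper: separate the model name from the remaining arguments."""
    if isinstance(args, list):
        # The model name is expected to be the first argument.
        return args[0], args[1:]
    if isinstance(args, Namespace):
        rest = dict(vars(args))
        name = rest.pop("model_name")
        return name, Namespace(**rest)
    rest = dict(args)
    name = rest.pop("model_name")
    return name, rest


def dispatch(args: Any, registry: dict[str, Callable[[Any], Any]]) -> Any:
    """Hypothetical dispatcher: look up the model CLI by name and run it on the rest."""
    if args is None:
        # Fall back to the command line, skipping the program name.
        args = sys.argv[1:]
    name, rest = split_model_name(args)
    if name not in registry:
        raise ValueError(f"unknown model name: {name!r}")
    return registry[name](rest)


# A toy registry whose entry just echoes what it receives.
registry = {"tdigest": lambda rest: ("tdigest", rest)}
from_list = dispatch(["tdigest", "fit", "--data.batch_size", "100"], registry)
from_dict = dispatch({"model_name": "tdigest", "fit": {}}, registry)
print(from_list, from_dict)
```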

class cellarium.ml.cli.LightningModule__load_from_checkpoint_class(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any)

Bases: CellariumModule, ClassFromFunctionBase

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

Any arguments specified through **kwargs will override args stored in "hyper_parameters".

Parameters:
  • checkpoint_path (str | Path | IO) – Path to checkpoint. This can also be a URL or a file-like object.

  • map_location (device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file (str | Path | None) –

    Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32
    

    You most likely won’t need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.

    If your model’s hparams argument is Namespace and .yaml file has hierarchical structure, you need to refactor your model to treat hparams as dict.

  • strict (bool | None) – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict. Defaults to True unless LightningModule.strict_loading is set, in which case it defaults to the value of LightningModule.strict_loading.

  • **kwargs (Any) – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.

Returns:

LightningModule instance with loaded weights and hyperparameters (if available).

Return type:

Self

Note

load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance, or a TypeError will be raised.

Note

To ensure all layers can be loaded from the checkpoint, this function will call configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, note that load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).

Example:

# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict with the loaded model
model.eval()
model.freeze()
y_hat = model(x)

__new__(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) → Self

wrapped_function(checkpoint_path: str | Path | IO, map_location: device | str | int | Callable[[UntypedStorage, str], UntypedStorage | None] | Dict[device | str | int, device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any) → Self
