Example CLI workflow

This example workflow demonstrates how to:

  1. Compute the mean and standard deviation of normalized and log1p-transformed data.

  2. Perform PCA on data that has been z-scored based on the mean and standard deviation from step 1.

  3. Train a logistic regression classifier on the principal component embeddings of the data.

To execute the workflow, download the config files from here and run the scripts in sequence:

cellarium-ml onepass_mean_var_std fit --config onepass_train_config.yaml
cellarium-ml incremental_pca fit --config ipca_train_config.yaml
cellarium-ml logistic_regression fit --config lr_train_config.yaml
cellarium-ml logistic_regression fit --config lr_resume_train_config.yaml
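Steps 2 to 4 read checkpoints written by the previous step, so the commands have to run in order and a failure should stop the sequence. Below is a minimal Python sketch of such a sequential runner. The run_workflow helper and its injectable runner argument are hypothetical conveniences and are not part of cellarium-ml.

```python
import subprocess
from typing import Callable, Sequence

# The four commands from the workflow above, in execution order.
COMMANDS: list[list[str]] = [
    ["cellarium-ml", "onepass_mean_var_std", "fit", "--config", "onepass_train_config.yaml"],
    ["cellarium-ml", "incremental_pca", "fit", "--config", "ipca_train_config.yaml"],
    ["cellarium-ml", "logistic_regression", "fit", "--config", "lr_train_config.yaml"],
    ["cellarium-ml", "logistic_regression", "fit", "--config", "lr_resume_train_config.yaml"],
]


def run_workflow(
    commands: Sequence[Sequence[str]] = COMMANDS,
    runner: Callable[..., subprocess.CompletedProcess] = subprocess.run,
) -> int:
    """Run each command in turn and return how many completed.

    check=True makes subprocess.run raise CalledProcessError on the first
    non-zero exit code, so later steps never start without their checkpoint.
    """
    completed = 0
    for cmd in commands:
        runner(list(cmd), check=True)
        completed += 1
    return completed
```

Calling run_workflow() from a directory that contains the four config files runs the whole example. The runner parameter only exists so the ordering logic can be exercised without cellarium-ml installed.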

Below we explain how the config files were created and what changes were made to the default configuration.

1. OnePassMeanVarStd

Generate a default config file:

cellarium-ml onepass_mean_var_std fit --print_config > onepass_train_config.yaml

Below we highlight the changes made to the default configuration file.
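The highlighted changes use the notation of the classic diff utility's normal output format. A line starting with < comes from the default file, --- separates the two sides, and a line starting with > is its replacement. If you want a similar listing for your own edits, the Python sketch below produces it with the standard library's difflib. The normal_diff name is hypothetical, and unlike diff it omits hunk headers such as 3c3.

```python
import difflib


def normal_diff(old: list[str], new: list[str]) -> list[str]:
    """Render changed lines in the '<' / '---' / '>' style used on this page."""
    out: list[str] = []
    matcher = difflib.SequenceMatcher(a=old, b=new, autojunk=False)
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == "equal":
            continue
        out.extend("< " + line for line in old[i1:i2])  # lines from the default file
        if tag == "replace":
            out.append("---")
        out.extend("> " + line for line in new[j1:j2])  # lines from the edited file
    return out
```

For example, comparing ["trainer:", "  devices: auto"] with ["trainer:", "  devices: 2"] gives the three lines "<   devices: auto", "---" and ">   devices: 2".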

trainer

Change the number of devices and set the path for logs and weights:

<   devices: auto
<   default_root_dir: null
---
>   devices: 2
>   default_root_dir: /tmp/test_examples/onepass

model

Add NormalizeTotal and Log1p transforms:

<   transforms: null
---
>   transforms:
>     - class_path: cellarium.ml.transforms.NormalizeTotal
>       init_args:
>         target_count: 10_000
>     - cellarium.ml.transforms.Log1p
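Conceptually, NormalizeTotal rescales each cell so that its counts sum to target_count, and Log1p then applies log(1 + x) elementwise. The NumPy sketch below assumes that NormalizeTotal divides each row of x_ng by the matching entry of total_mrna_umis_n, the two inputs named in the DataLoader section. It illustrates the arithmetic only and is not cellarium-ml's implementation.

```python
import numpy as np


def normalize_total(x_ng: np.ndarray, total_mrna_umis_n: np.ndarray,
                    target_count: float = 10_000) -> np.ndarray:
    # Scale row n by target_count / total_n so every cell ends up with the same total.
    return x_ng * (target_count / total_mrna_umis_n)[:, None]


def log1p(x_ng: np.ndarray) -> np.ndarray:
    # log(1 + x) keeps zeros at zero and compresses large counts.
    return np.log1p(x_ng)


x_ng = np.array([[1.0, 3.0, 0.0],
                 [10.0, 0.0, 10.0]])
total_mrna_umis_n = x_ng.sum(axis=1)  # in this toy example the totals are the row sums
normalized_ng = normalize_total(x_ng, total_mrna_umis_n)
y_ng = log1p(normalized_ng)
```

Here the first row becomes [2500, 7500, 0] and the second becomes [5000, 0, 5000] before the logarithm is applied.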

Change OnePassMeanVarStd’s algorithm to shifted_data:

<       algorithm: naive
---
>       algorithm: shifted_data
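The naive algorithm computes the variance as E[x²] − E[x]². It loses precision when the mean is large compared with the spread. The textbook shifted-data algorithm first subtracts a constant shift K, commonly the first observation, and then accumulates the sums. The shift leaves the variance unchanged but makes catastrophic cancellation much less likely. Below is a one-pass NumPy sketch of that textbook version that processes the data in batches, as a DataLoader would. How cellarium-ml chooses K and combines results across devices is not described here, so treat those details as assumptions.

```python
import numpy as np


class ShiftedDataMeanVar:
    """One-pass per-gene mean and (population) variance using the shifted-data algorithm."""

    def __init__(self) -> None:
        self.n = 0
        self.shift_g = None  # K: fixed per-gene shift, taken from the first row seen
        self.sum_g = None    # running sum of (x - K)
        self.sumsq_g = None  # running sum of (x - K)^2

    def update(self, x_ng: np.ndarray) -> None:
        if self.shift_g is None:
            self.shift_g = x_ng[0].astype(np.float64)  # copy, not a view
            self.sum_g = np.zeros_like(self.shift_g)
            self.sumsq_g = np.zeros_like(self.shift_g)
        d_ng = x_ng - self.shift_g
        self.sum_g += d_ng.sum(axis=0)
        self.sumsq_g += (d_ng ** 2).sum(axis=0)
        self.n += x_ng.shape[0]

    @property
    def mean_g(self) -> np.ndarray:
        return self.shift_g + self.sum_g / self.n

    @property
    def var_g(self) -> np.ndarray:
        # (sum d^2 - (sum d)^2 / n) / n; the shift K cancels out of the variance.
        return (self.sumsq_g - self.sum_g ** 2 / self.n) / self.n

    @property
    def std_g(self) -> np.ndarray:
        return np.sqrt(self.var_g)


rng = np.random.default_rng(0)
x_ng = 1e8 + rng.normal(size=(400, 3))  # large offset, unit variance
stats = ShiftedDataMeanVar()
for start in range(0, 400, 100):  # batches of 100 cells
    stats.update(x_ng[start:start + 100])
```

This sketch reports the population variance (ddof=0). Whether cellarium-ml uses the same convention is not stated on this page.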

data

Configure the DistributedAnnDataCollection. Here we validate the obs columns used by the transforms and the model (total_mrna_umis). Each loaded AnnData file is validated against the first (reference) AnnData file by checking that the column names and dtypes match between the two:

<       filenames: null
<       shard_size: null
<       max_cache_size: 1
<       obs_columns_to_validate: null
---
>       filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
>       shard_size: 100
>       max_cache_size: 2
>       obs_columns_to_validate:
>         - total_mrna_umis
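The check described above amounts to comparing the obs columns of each file with those of the reference file. Below is a pandas sketch of it. It assumes that only the listed columns are compared and that both their presence and their dtype must match. The validate_obs_columns name is hypothetical and is not cellarium-ml's API.

```python
import pandas as pd


def validate_obs_columns(reference_obs: pd.DataFrame, obs: pd.DataFrame,
                         columns: list[str]) -> None:
    """Raise ValueError if a listed column is missing or its dtype differs
    from the one in the reference (first) AnnData file's obs."""
    for col in columns:
        if col not in reference_obs.columns:
            raise ValueError(f"{col!r} is missing from the reference obs")
        if col not in obs.columns:
            raise ValueError(f"{col!r} is missing from obs")
        if obs[col].dtype != reference_obs[col].dtype:
            raise ValueError(
                f"{col!r} has dtype {obs[col].dtype}, expected {reference_obs[col].dtype}"
            )
```

For instance, a file that stores total_mrna_umis as int64 would be rejected if the reference file stores it as float64.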

Configure the DataLoader. batch_keys have to include all input arguments to the transforms and the model. For example, NormalizeTotal’s arguments are x_ng and total_mrna_umis_n, Log1p’s argument is x_ng, and OnePassMeanVarStd’s arguments are x_ng and var_names_g:

<   batch_keys: null
<   batch_size: 1
<   num_workers: 0
---
>   batch_keys:
>     x_ng:
>       attr: X
>       convert_fn: cellarium.ml.utilities.data.densify
>     var_names_g:
>       attr: var_names
>     total_mrna_umis_n:
>       attr: obs
>       key: total_mrna_umis
>   batch_size: 100
>   num_workers: 2
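Each entry of batch_keys names one argument of the transforms or the model and says where to fetch it. The attr field selects an AnnData attribute, key optionally indexes into that attribute (for example, a column of obs), and convert_fn post-processes the result. The sketch below mimics this lookup on a minimal stand-in object. The build_batch and densify helpers and the SimpleNamespace stand-in are hypothetical and only illustrate the idea. They are not cellarium-ml's collate code, and convert_fn is passed as a callable here rather than as a dotted import path.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional

import numpy as np
import pandas as pd


def densify(x: Any) -> np.ndarray:
    # AnnData.X is often a scipy.sparse matrix, which exposes .toarray();
    # dense input passes through unchanged.
    return x.toarray() if hasattr(x, "toarray") else np.asarray(x)


def build_batch(adata: Any, batch_keys: dict[str, dict[str, Any]]) -> dict[str, Any]:
    batch = {}
    for name, spec in batch_keys.items():
        value = getattr(adata, spec["attr"])  # e.g. adata.X or adata.obs
        if spec.get("key") is not None:
            value = value[spec["key"]]        # e.g. adata.obs["total_mrna_umis"]
        convert_fn: Optional[Callable] = spec.get("convert_fn")
        if convert_fn is not None:
            value = convert_fn(value)
        batch[name] = value
    return batch


adata = SimpleNamespace(
    X=np.array([[0.0, 2.0], [3.0, 0.0]]),
    var_names=pd.Index(["GENE_A", "GENE_B"]),
    obs=pd.DataFrame({"total_mrna_umis": [2.0, 3.0]}),
)
batch_keys = {
    "x_ng": {"attr": "X", "convert_fn": densify},
    "var_names_g": {"attr": "var_names"},
    "total_mrna_umis_n": {"attr": "obs", "key": "total_mrna_umis"},
}
batch = build_batch(adata, batch_keys)
```

The resulting keys match the argument names, so each transform or model can read exactly the inputs it declares.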

2. IncrementalPCA

Generate a default config file:

cellarium-ml incremental_pca fit --print_config > ipca_train_config.yaml

Below we highlight the changes made to the default configuration file.

trainer

Change the number of devices and set the path for logs and weights:

<   devices: auto
<   default_root_dir: null
---
>   devices: 2
>   default_root_dir: /tmp/test_examples/ipca

model

Add the NormalizeTotal, Log1p, and ZScore transforms. Note that the mean_g, std_g, and var_names_g parameters of the ZScore transform are loaded from the OnePassMeanVarStd checkpoint:

Note

cellarium-ml does not validate the transforms applied to the data. Always verify yourself that the transforms are configured correctly. If they are not, your model will silently produce wrong results. In the example below, we first apply the NormalizeTotal and Log1p transforms to the data and then apply the ZScore transform. Importantly, the mean_g and std_g parameters of the ZScore transform were calculated by the OnePassMeanVarStd model on data that was also transformed with NormalizeTotal and Log1p.

<   transforms: null
---
>   transforms:
>     - class_path: cellarium.ml.transforms.NormalizeTotal
>       init_args:
>         target_count: 10_000
>     - cellarium.ml.transforms.Log1p
>     - class_path: cellarium.ml.transforms.ZScore
>       init_args:
>         mean_g:
>           !CheckpointLoader
>           file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>           attr: model.mean_g
>           convert_fn: null
>         std_g:
>           !CheckpointLoader
>           file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>           attr: model.std_g
>           convert_fn: null
>         var_names_g:
>           !CheckpointLoader
>           file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>           attr: model.var_names_g
>           convert_fn: numpy.ndarray.tolist
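ZScore standardizes each gene with the statistics from step 1, z = (x − mean_g) / std_g. Because mean_g and std_g are indexed by gene, they are meaningful only if the genes in a batch appear in the same order as the var_names_g loaded from the checkpoint, which is presumably why var_names_g is loaded as well. The NumPy sketch below uses a hypothetical zscore helper that refuses to run on mismatched gene names. It also shows the point of the Note above: when the statistics are computed on the same transformed data, every column ends up with mean 0 and standard deviation 1.

```python
import numpy as np


def zscore(x_ng: np.ndarray, var_names_g: list[str], mean_g: np.ndarray,
           std_g: np.ndarray, ref_var_names_g: list[str]) -> np.ndarray:
    # The per-gene statistics only line up if the gene order matches the checkpoint.
    if list(var_names_g) != list(ref_var_names_g):
        raise ValueError("var_names_g does not match the checkpoint's var_names_g")
    return (x_ng - mean_g) / std_g


rng = np.random.default_rng(0)
genes = ["GENE_A", "GENE_B", "GENE_C"]
# Stand-in for the output of NormalizeTotal + Log1p (log-normalized expression).
logx_ng = np.log1p(rng.poisson(5.0, size=(200, 3)).astype(float))
# Statistics computed on the *same* transformed data, as step 1 does.
mean_g, std_g = logx_ng.mean(axis=0), logx_ng.std(axis=0)
z_ng = zscore(logx_ng, genes, mean_g, std_g, genes)
```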

Set the number of components for IncrementalPCA:

<       n_components: null
---
>       n_components: 50
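n_components sets how many principal directions are kept. The sketch below does not reproduce the incremental PCA algorithm that IncrementalPCA is named after. It substitutes a simpler one-pass method for illustration: it accumulates per-gene sums and the Gram matrix Xᵀ X batch by batch, then keeps the top eigenvectors of the resulting covariance. For G genes this needs O(G²) memory, which becomes costly when G is large. The streaming_pca name is hypothetical.

```python
import numpy as np


def streaming_pca(batches, n_components: int):
    """Return (mean_g, components_kg) from a single pass over (n, G) batches."""
    n = 0
    sum_g = None
    gram_gg = None
    for x_ng in batches:
        x_ng = np.asarray(x_ng, dtype=np.float64)
        if sum_g is None:
            g = x_ng.shape[1]
            sum_g = np.zeros(g)
            gram_gg = np.zeros((g, g))
        sum_g += x_ng.sum(axis=0)
        gram_gg += x_ng.T @ x_ng
        n += x_ng.shape[0]
    mean_g = sum_g / n
    # For z-scored input the mean is close to 0, so this E[xx^T] - mm^T form is adequate.
    cov_gg = gram_gg / n - np.outer(mean_g, mean_g)
    # eigh returns eigenvalues in ascending order; keep the largest n_components.
    eigvals, eigvecs = np.linalg.eigh(cov_gg)
    order = np.argsort(eigvals)[::-1][:n_components]
    return mean_g, eigvecs[:, order].T


rng = np.random.default_rng(0)
x_ng = rng.normal(size=(400, 10)) @ rng.normal(size=(10, 10))
mean_g, components_kg = streaming_pca(
    (x_ng[i:i + 100] for i in range(0, 400, 100)), n_components=3
)
embedding_nk = (x_ng - mean_g) @ components_kg.T  # principal component embedding
```

The embedding_nk array plays the role of the principal component embeddings that step 3 trains on.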

data

Configure the DistributedAnnDataCollection. Here we validate obs columns that are used by the transforms and the model (total_mrna_umis):

<       filenames: null
<       shard_size: null
<       max_cache_size: 1
<       obs_columns_to_validate: null
---
>       filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
>       shard_size: 100
>       max_cache_size: 2
>       obs_columns_to_validate:
>         - total_mrna_umis
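The filenames value test_{0..3}.h5ad is a brace pattern, in the style of shell brace expansion, that stands for four files, test_0.h5ad through test_3.h5ad. A minimal sketch that expands a single numeric range is shown below. expand_numeric_range is a hypothetical helper: it handles only the first {a..b} group and ignores zero-padding, whereas a general brace-expansion library also handles lists and nesting.

```python
import re

_RANGE = re.compile(r"\{(\d+)\.\.(\d+)\}")


def expand_numeric_range(pattern: str) -> list[str]:
    """Expand the first {a..b} integer range in pattern, including both ends."""
    match = _RANGE.search(pattern)
    if match is None:
        return [pattern]
    start, stop = int(match.group(1)), int(match.group(2))
    step = 1 if stop >= start else -1  # like bash, {3..0} counts down
    return [
        pattern[: match.start()] + str(i) + pattern[match.end():]
        for i in range(start, stop + step, step)
    ]


urls = expand_numeric_range(
    "https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad"
)
```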

Configure the DataLoader. batch_keys contains the same keys as for OnePassMeanVarStd above:

<   batch_keys: null
<   batch_size: 1
<   num_workers: 0
---
>   batch_keys:
>     x_ng:
>       attr: X
>       convert_fn: cellarium.ml.utilities.data.densify
>     var_names_g:
>       attr: var_names
>     total_mrna_umis_n:
>       attr: obs
>       key: total_mrna_umis
>   batch_size: 100
>   num_workers: 2

3. LogisticRegression

Generate a default config file:

cellarium-ml logistic_regression fit --print_config > lr_train_config.yaml

Below we highlight the changes made to the default configuration file.

trainer

Change the number of devices, save a checkpoint after every training step, set the number of epochs, log after every training step, and set the path for logs and weights:

<   devices: auto
<   callbacks: null
<   max_epochs: null
<   log_every_n_steps: null
<   default_root_dir: null
---
>   devices: 2
>   callbacks:
>   - class_path: lightning.pytorch.callbacks.ModelCheckpoint
>     init_args:
>       every_n_train_steps: 1
>       save_top_k: -1
>   max_epochs: 5
>   log_every_n_steps: 1
>   default_root_dir: /tmp/test_examples/lr
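With every_n_train_steps: 1 and save_top_k: -1, ModelCheckpoint writes a checkpoint after every training step and never deletes old ones. The number of files therefore grows as max_epochs × steps per epoch. The sketch below lists the expected file names under three assumptions: Lightning's default {epoch}-{step} file name (where step is the global step count after the step completes), an even split of cells across devices, and a partial last batch that still counts as a step. The total of 400 cells is also an assumption, not a figure stated by the source. It is consistent with the checkpoint names used on this page: epoch=0-step=2 for batch size 100 and epoch=1-step=13 for batch size 25, both on 2 devices.

```python
import math


def expected_checkpoints(n_cells: int, devices: int, batch_size: int,
                         max_epochs: int) -> list[str]:
    """File names for ModelCheckpoint(every_n_train_steps=1, save_top_k=-1)."""
    # Each device sees n_cells / devices cells; a partial final batch still counts as a step.
    steps_per_epoch = math.ceil(n_cells / devices / batch_size)
    names = []
    for epoch in range(max_epochs):
        for i in range(1, steps_per_epoch + 1):
            step = epoch * steps_per_epoch + i  # global step after this batch
            names.append(f"epoch={epoch}-step={step}.ckpt")
    return names


names = expected_checkpoints(n_cells=400, devices=2, batch_size=25, max_epochs=5)
```

Under these assumptions, the logistic regression run writes 40 checkpoints, and epoch=1-step=13.ckpt is one of them.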

model

Add the trained IncrementalPCA model as a transform. Note that the trained model contains the NormalizeTotal, Log1p, and ZScore transforms in its pipeline:

<   transforms: null
---
>   transforms:
>     - !CheckpointLoader
>       file_path: /tmp/test_examples/ipca/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>       attr: null
>       convert_fn: null
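Loading the whole IncrementalPCA checkpoint (attr: null) reuses its pipeline. Raw counts passed to the logistic regression therefore presumably go through NormalizeTotal, Log1p, ZScore and the PCA projection, in that order. The sketch below models such a pipeline as a list of functions, each of which reads a shared batch dictionary and returns the keys it updates. The function names, the dictionary convention and the stand-in statistics are hypothetical and were chosen only to mirror the argument names on this page.

```python
from typing import Callable

import numpy as np

Batch = dict[str, np.ndarray]
Step = Callable[[Batch], Batch]


def run_pipeline(batch: Batch, steps: list[Step]) -> Batch:
    # Apply each step in order; each returns the keys it adds or overwrites.
    for step in steps:
        batch = {**batch, **step(batch)}
    return batch


def make_steps(mean_g: np.ndarray, std_g: np.ndarray, components_kg: np.ndarray,
               target_count: float = 10_000) -> list[Step]:
    return [
        # NormalizeTotal
        lambda b: {"x_ng": b["x_ng"] * (target_count / b["total_mrna_umis_n"])[:, None]},
        # Log1p
        lambda b: {"x_ng": np.log1p(b["x_ng"])},
        # ZScore
        lambda b: {"x_ng": (b["x_ng"] - mean_g) / std_g},
        # PCA projection (no re-centering: z-scored data already has mean close to 0)
        lambda b: {"x_ng": b["x_ng"] @ components_kg.T},
    ]


rng = np.random.default_rng(0)
counts_ng = rng.poisson(3.0, size=(6, 4)).astype(float) + 1.0
batch = {"x_ng": counts_ng, "total_mrna_umis_n": counts_ng.sum(axis=1)}
# Stand-in statistics and components; in the workflow these come from the checkpoints.
out = run_pipeline(batch, make_steps(np.zeros(4), np.ones(4), np.eye(4)[:2]))
```

Keeping the transforms inside the saved model means the logistic regression cannot accidentally be fed data preprocessed differently from the data its PCA was fitted on.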

Set the optimizer and its learning rate:

<   optim_fn: null
<   optim_kwargs: null
---
>   optim_fn: torch.optim.Adam
>   optim_kwargs:
>     lr: 0.1
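optim_fn and optim_kwargs presumably select the optimizer class and its keyword arguments, here torch.optim.Adam with lr=0.1. To show what that optimizer does to logistic regression weights, the sketch below re-implements multinomial logistic regression and the standard Adam update in plain NumPy, using PyTorch's default betas (0.9, 0.999) and eps 1e-8. It illustrates the technique under those assumptions and is not cellarium-ml's LogisticRegression, which may differ in details such as priors or parameterization.

```python
import numpy as np


def softmax(z_nc: np.ndarray) -> np.ndarray:
    z_nc = z_nc - z_nc.max(axis=1, keepdims=True)  # subtract row max for stability
    e = np.exp(z_nc)
    return e / e.sum(axis=1, keepdims=True)


def train_logreg_adam(x_nk, y_n, n_classes, lr=0.1, n_steps=200,
                      beta1=0.9, beta2=0.999, eps=1e-8):
    """Full-batch multinomial logistic regression fitted with Adam."""
    n, k = x_nk.shape
    params = {"W_kc": np.zeros((k, n_classes)), "b_c": np.zeros(n_classes)}
    m = {name: np.zeros_like(p) for name, p in params.items()}
    v = {name: np.zeros_like(p) for name, p in params.items()}
    onehot_nc = np.eye(n_classes)[y_n]
    for t in range(1, n_steps + 1):
        # Gradient of the mean cross-entropy loss.
        err_nc = (softmax(x_nk @ params["W_kc"] + params["b_c"]) - onehot_nc) / n
        grads = {"W_kc": x_nk.T @ err_nc, "b_c": err_nc.sum(axis=0)}
        for name in params:
            m[name] = beta1 * m[name] + (1 - beta1) * grads[name]
            v[name] = beta2 * v[name] + (1 - beta2) * grads[name] ** 2
            m_hat = m[name] / (1 - beta1 ** t)  # bias correction
            v_hat = v[name] / (1 - beta2 ** t)
            params[name] -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return params


rng = np.random.default_rng(0)
centers_ck = np.array([[3.0, 0.0], [-3.0, 0.0], [0.0, 3.0]])
y_n = rng.integers(0, 3, size=300)
x_nk = centers_ck[y_n] + rng.normal(scale=0.5, size=(300, 2))
params = train_logreg_adam(x_nk, y_n, n_classes=3)
pred_n = (x_nk @ params["W_kc"] + params["b_c"]).argmax(axis=1)
```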

data

Configure the DistributedAnnDataCollection. Here we validate the obs columns used by the transforms and the model (total_mrna_umis, as above, plus the assay column):

<       filenames: null
<       shard_size: null
<       max_cache_size: 1
<       obs_columns_to_validate: null
---
>       filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
>       shard_size: 100
>       max_cache_size: 2
>       obs_columns_to_validate:
>         - total_mrna_umis
>         - assay

Configure the DataLoader. batch_keys contains the same keys as above plus y_n, which is an input argument of LogisticRegression:

<   batch_keys: null
<   batch_size: 1
<   num_workers: 0
---
>   batch_keys:
>     x_ng:
>       attr: X
>       convert_fn: cellarium.ml.utilities.data.densify
>     var_names_g:
>       attr: var_names
>     total_mrna_umis_n:
>       attr: obs
>       key: total_mrna_umis
>     y_n:
>       attr: obs
>       key: assay
>       convert_fn: cellarium.ml.utilities.data.categories_to_codes
>   batch_size: 25
>   num_workers: 2
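y_n has to hold integer class indices, while the assay column in obs holds category labels. The sketch below shows one plausible conversion using pandas categorical codes. The helper name is hypothetical, and it is an assumption that cellarium.ml.utilities.data.categories_to_codes behaves exactly like this. The assay labels are illustrative only. Note that codes are positions in the category list, so they agree across files only if every file has the same categories in the same order.

```python
import numpy as np
import pandas as pd


def categories_to_codes_sketch(column: pd.Series) -> np.ndarray:
    # .cat.codes gives each value's position in column.cat.categories (-1 for missing).
    return column.cat.codes.to_numpy()


# Hypothetical assay labels; pandas sorts the inferred categories.
assay = pd.Series(["10x 3' v3", "10x 3' v2", "10x 3' v3", "Smart-seq2"], dtype="category")
y_n = categories_to_codes_sketch(assay)
```

Here the categories are ["10x 3' v2", "10x 3' v3", "Smart-seq2"], so y_n is [1, 0, 1, 2].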

4. Resume training

To resume training the logistic regression model from a saved checkpoint, add the checkpoint file path to the config file:

< ckpt_path: null
---
> ckpt_path: /tmp/test_examples/lr/lightning_logs/version_0/checkpoints/epoch=1-step=13.ckpt
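The resume path above points to one specific file. The trainer saves a checkpoint after every step, and Lightning creates a new version_N log directory on each run, so it can help to pick the most recent checkpoint programmatically before setting ckpt_path. Below is a minimal sketch that assumes the default epoch=E-step=S.ckpt file names. latest_checkpoint is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Optional

_CKPT = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt$")


def latest_checkpoint(root: str) -> Optional[Path]:
    """Return the checkpoint with the highest (version, step) under root, or None."""
    best: Optional[Path] = None
    best_key: Optional[tuple[int, int]] = None
    for path in Path(root).glob("lightning_logs/version_*/checkpoints/*.ckpt"):
        match = _CKPT.search(path.name)
        version = path.parent.parent.name.removeprefix("version_")
        if match is None or not version.isdigit():
            continue
        # Compare numbers, not strings: "step=13" must beat "step=9".
        key = (int(version), int(match.group(2)))
        if best_key is None or key > best_key:
            best, best_key = path, key
    return best
```

For example, latest_checkpoint("/tmp/test_examples/lr") returns the newest checkpoint of the most recent run, or None if nothing has been saved yet.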