Example CLI workflow
This example workflow demonstrates how to:
1. Compute the mean and standard deviation of normalized and log1p-transformed data.
2. Perform PCA on data that has been z-scored based on the mean and standard deviation from step 1.
3. Train a logistic regression classifier on the principal component embeddings of the data.
To execute the workflow, download the config files from here and run the commands in sequence:
cellarium-ml onepass_mean_var_std fit --config onepass_train_config.yaml
cellarium-ml incremental_pca fit --config ipca_train_config.yaml
cellarium-ml logistic_regression fit --config lr_train_config.yaml
cellarium-ml logistic_regression fit --config lr_resume_train_config.yaml
Below we explain how the config files were created and what changes were made to the default configuration.
1. OnePassMeanVarStd
Generate a default config file:
cellarium-ml onepass_mean_var_std fit --print_config > onepass_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices and set the path for logs and weights:
< devices: auto
< default_root_dir: null
---
> devices: 2
> default_root_dir: /tmp/test_examples/onepass
model
Add NormalizeTotal and Log1p transforms:
< transforms: null
---
> transforms:
> - class_path: cellarium.ml.transforms.NormalizeTotal
> init_args:
> target_count: 10_000
> - cellarium.ml.transforms.Log1p
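For reference, here is a minimal NumPy sketch of what these two transforms compute (an illustration only, not the library's implementation): NormalizeTotal rescales each cell so that its counts sum to target_count, and Log1p then applies log(1 + x) elementwise. The toy arrays are made up.

import numpy as np

def normalize_total(x_ng: np.ndarray, total_mrna_umis_n: np.ndarray, target_count: float = 10_000) -> np.ndarray:
    # Rescale each cell (row) so that its total count equals target_count.
    return x_ng * (target_count / total_mrna_umis_n)[:, None]

x_ng = np.array([[1.0, 3.0, 6.0], [0.0, 5.0, 5.0]])        # toy counts: 2 cells x 3 genes
total_mrna_umis_n = x_ng.sum(axis=1)                        # total UMIs per cell
x_ng = np.log1p(normalize_total(x_ng, total_mrna_umis_n))   # NormalizeTotal followed by Log1p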
Change OnePassMeanVarStd’s algorithm to shifted_data:
< algorithm: naive
---
> algorithm: shifted_data
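The shifted_data option selects the classic numerically stable variant of one-pass mean/variance estimation: instead of accumulating sums of x and x² directly (the naive approach), it accumulates sums of x − K and (x − K)² for a fixed shift K close to the data. Below is a generic sketch of that formula, assuming this is essentially what the option does; the library's exact accumulation details may differ.

import numpy as np

def shifted_data_mean_var(batches, shift):
    # Accumulate sums of (x - shift) and (x - shift)**2 in a single pass.
    n, s1, s2 = 0, 0.0, 0.0
    for x in batches:
        d = x - shift
        n += d.size
        s1 += d.sum()
        s2 += (d ** 2).sum()
    mean = shift + s1 / n
    var = (s2 - s1 ** 2 / n) / n   # population variance
    return mean, var

# Shifting by a value near the data keeps the accumulated sums small and stable.
batches = [np.array([1e6 + 1.0, 1e6 + 2.0]), np.array([1e6 + 3.0])]
mean, var = shifted_data_mean_var(batches, shift=1e6)   # mean = 1e6 + 2, var = 2/3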
data
Configure the DistributedAnnDataCollection. Here we validate the obs columns that are used by the transforms and the model (total_mrna_umis). Each loaded AnnData file is validated against the first (reference) AnnData file by checking that the column names and dtypes match:
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
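Conceptually, the validation described above amounts to the following check for every newly loaded file. This is a hypothetical helper for illustration, not the library's implementation.

import anndata

def validate_obs_columns(reference: anndata.AnnData, loaded: anndata.AnnData, columns: list[str]) -> None:
    # Compare the loaded file's obs columns against the reference (first) file.
    for col in columns:
        if col not in loaded.obs.columns:
            raise ValueError(f"obs column {col!r} is missing from the loaded file")
        if loaded.obs[col].dtype != reference.obs[col].dtype:
            raise ValueError(
                f"obs column {col!r} has dtype {loaded.obs[col].dtype}, "
                f"expected {reference.obs[col].dtype}"
            )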
Configure the DataLoader. batch_keys must include all input arguments of the transforms and the model. For example, NormalizeTotal's arguments are x_ng and total_mrna_umis_n, Log1p's argument is x_ng, and OnePassMeanVarStd's arguments are x_ng and var_names_g (see the sketch after the diff):
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
> x_ng:
> attr: X
> convert_fn: cellarium.ml.utilities.data.densify
> var_names_g:
> attr: var_names
> total_mrna_umis_n:
> attr: obs
> key: total_mrna_umis
> batch_size: 100
> num_workers: 2
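Roughly, each entry of batch_keys tells the data loading code which AnnData attribute to read and how to convert it; the configuration above would produce a batch dictionary like the one sketched below. This is an illustration only; densify here stands in for cellarium.ml.utilities.data.densify, which we assume converts a sparse matrix to a dense array.

import numpy as np
import scipy.sparse as sp
import anndata

def densify(x):
    # Stand-in for cellarium.ml.utilities.data.densify (assumed behavior).
    return np.asarray(x.todense()) if sp.issparse(x) else np.asarray(x)

def make_batch(adata: anndata.AnnData) -> dict:
    return {
        "x_ng": densify(adata.X),                                       # attr: X, convert_fn: densify
        "var_names_g": np.asarray(adata.var_names),                     # attr: var_names
        "total_mrna_umis_n": adata.obs["total_mrna_umis"].to_numpy(),   # attr: obs, key: total_mrna_umis
    }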
2. IncrementalPCA
Generate a default config file:
cellarium-ml incremental_pca fit --print_config > ipca_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices and set the path for logs and weights:
< devices: auto
< default_root_dir: null
---
> devices: 2
> default_root_dir: /tmp/test_examples/ipca
model
Add NormalizeTotal, Log1p, and ZScore transforms. Note that mean_g, std_g, and var_names_g of the ZScore transform are loaded from the OnePassMeanVarStd checkpoint:
Note
cellarium-ml does not perform any validation on the transforms being applied to the data. Always verify yourself that the transforms are configured correctly; if they are not, your model will silently produce wrong results. In the example below, we first apply the NormalizeTotal and Log1p transforms to the data and then apply the ZScore transform. Importantly, the mean_g and std_g parameters of the ZScore transform were calculated by the OnePassMeanVarStd model on data that had also been transformed with NormalizeTotal and Log1p.
< transforms: null
---
> transforms:
> - class_path: cellarium.ml.transforms.NormalizeTotal
> init_args:
> target_count: 10_000
> - cellarium.ml.transforms.Log1p
> - class_path: cellarium.ml.transforms.ZScore
> init_args:
> mean_g:
> !CheckpointLoader
> file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
> attr: model.mean_g
> convert_fn: null
> std_g:
> !CheckpointLoader
> file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
> attr: model.std_g
> convert_fn: null
> var_names_g:
> !CheckpointLoader
> file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
> attr: model.var_names_g
> convert_fn: numpy.ndarray.tolist
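For orientation, the ZScore transform standardizes each gene with the statistics pulled from the checkpoint, roughly as in the sketch below (not the library's implementation; safeguards such as a small epsilon may apply in practice). Because mean_g and std_g were computed on NormalizeTotal + Log1p data, those two transforms must run before ZScore here, exactly as configured above.

import torch

def z_score(x_ng: torch.Tensor, mean_g: torch.Tensor, std_g: torch.Tensor) -> torch.Tensor:
    # Standardize each gene using the statistics computed by OnePassMeanVarStd.
    return (x_ng - mean_g) / std_g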
Set the number of components for IncrementalPCA:
< n_components: null
---
> n_components: 50
data
Configure the DistributedAnnDataCollection. Here we validate obs columns that are used by the transforms and the model (total_mrna_umis):
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
Configure the DataLoader. batch_keys contains the same keys as for OnePassMeanVarStd above:
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
> x_ng:
> attr: X
> convert_fn: cellarium.ml.utilities.data.densify
> var_names_g:
> attr: var_names
> total_mrna_umis_n:
> attr: obs
> key: total_mrna_umis
> batch_size: 100
> num_workers: 2
3. LogisticRegression
Generate a default config file:
cellarium-ml logistic_regression fit --print_config > lr_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices, checkpoint the model at every training step, set the number of epochs, log at every training step, and set the path for logs and weights:
< devices: auto
< callbacks: null
< max_epochs: null
< log_every_n_steps: null
< default_root_dir: null
---
> devices: 2
> callbacks:
> - class_path: lightning.pytorch.callbacks.ModelCheckpoint
> init_args:
> every_n_train_steps: 1
> save_top_k: -1
> max_epochs: 5
> log_every_n_steps: 1
> default_root_dir: /tmp/test_examples/lr
model
Add the trained PCA model as a transform. Note that the trained PCA model contains the NormalizeTotal, Log1p, and ZScore transforms in its pipeline:
< transforms: null
---
> transforms:
> - !CheckpointLoader
> file_path: /tmp/test_examples/ipca/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
> attr: null
> convert_fn: null
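When the IncrementalPCA checkpoint is used as a transform, its own pipeline (NormalizeTotal, Log1p, ZScore) runs first and the result is projected onto the 50 fitted components, so LogisticRegression receives a 50-dimensional embedding as x_ng. Below is a rough sketch of the projection step with hypothetical names, not the library's API; since ZScore already centers the data, explicit mean subtraction is omitted here.

import torch

def pca_project(x_ng: torch.Tensor, components_kg: torch.Tensor) -> torch.Tensor:
    # Project (n, g) z-scored data onto k principal components -> (n, k) embedding.
    return x_ng @ components_kg.T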
Set the optimizer and its learning rate:
< optim_fn: null
< optim_kwargs: null
---
> optim_fn: torch.optim.Adam
> optim_kwargs:
> lr: 0.1
data
Configure the DistributedAnnDataCollection. Here we validate the obs columns that are used by the transforms and the model (total_mrna_umis as above and additionally the assay column):
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
> - assay
Configure the DataLoader. batch_keys contains the same keys as above, plus y_n, which is an argument to LogisticRegression (see the sketch after the diff):
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
> x_ng:
> attr: X
> convert_fn: cellarium.ml.utilities.data.densify
> var_names_g:
> attr: var_names
> total_mrna_umis_n:
> attr: obs
> key: total_mrna_umis
> y_n:
> attr: obs
> key: assay
> convert_fn: cellarium.ml.utilities.data.categories_to_codes
> batch_size: 25
> num_workers: 2
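The y_n labels come from the categorical assay column. We assume categories_to_codes behaves like pandas' categorical code conversion, roughly as in the sketch below (made-up example values; not the library's implementation).

import pandas as pd

assay = pd.Series(["10x 3' v2", "10x 3' v3", "10x 3' v2"], dtype="category")
y_n = assay.cat.codes.to_numpy()   # array([0, 1, 0]) - integer class labels for LogisticRegression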
4. Resume training
To resume training of the logistic regression model from a saved checkpoint, add the checkpoint filepath to the config file:
< ckpt_path: null
---
> ckpt_path: /tmp/test_examples/lr/lightning_logs/version_0/checkpoints/epoch=1-step=13.ckpt