Example CLI workflow
This example workflow demonstrates how to:
1. Compute the mean and standard deviation of normalized and log1p-transformed data.
2. Perform PCA on data that has been z-scored using the mean and standard deviation from step 1.
3. Train a logistic regression classifier on the principal component embeddings of the data.
To execute the workflow, download the config files from here and run the scripts in sequence:
cellarium-ml onepass_mean_var_std fit --config onepass_train_config.yaml
cellarium-ml incremental_pca fit --config ipca_train_config.yaml
cellarium-ml logistic_regression fit --config lr_train_config.yaml
cellarium-ml logistic_regression fit --config lr_resume_train_config.yaml
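Because each step reads a checkpoint written by the previous one, the order of the commands matters. The following is a minimal sketch of a driver that runs them in sequence and stops at the first failure. The helper name run_workflow and its dry_run flag are hypothetical and not part of cellarium-ml; only the four command lines come from the workflow above.

```python
import shlex
import subprocess

# The four commands from the workflow, in the order they must run: each
# later step reads a checkpoint written by an earlier one.
COMMANDS = [
    "cellarium-ml onepass_mean_var_std fit --config onepass_train_config.yaml",
    "cellarium-ml incremental_pca fit --config ipca_train_config.yaml",
    "cellarium-ml logistic_regression fit --config lr_train_config.yaml",
    "cellarium-ml logistic_regression fit --config lr_resume_train_config.yaml",
]


def run_workflow(commands=COMMANDS, dry_run=True):
    """Run the commands in order, stopping at the first failure.

    With dry_run=True (hypothetical flag) the commands are only parsed and
    returned, which allows checking the sequence without starting training.
    """
    parsed = [shlex.split(command) for command in commands]
    if not dry_run:
        for argv in parsed:
            # check=True raises CalledProcessError on a non-zero exit code,
            # so a failed step never hands a missing checkpoint to the next.
            subprocess.run(argv, check=True)
    return parsed
```

Calling run_workflow(dry_run=False) would launch the actual training runs, provided the config files are in the current directory.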
Below we explain how the config files were created and what changes were made to the default configuration.
1. OnePassMeanVarStd
Generate a default config file:
cellarium-ml onepass_mean_var_std fit --print_config > onepass_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices and set the path for logs and weights:
< devices: auto
< default_root_dir: null
---
> devices: 2
> default_root_dir: /tmp/test_examples/onepass
model
Add NormalizeTotal and Log1p transforms:
< transforms: null
---
> transforms:
> - class_path: cellarium.ml.transforms.NormalizeTotal
>   init_args:
>     target_count: 10_000
> - cellarium.ml.transforms.Log1p
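To make the effect of these two transforms concrete, here is a small NumPy sketch. It assumes that NormalizeTotal rescales each cell so that its counts sum to target_count (using the per-cell totals in total_mrna_umis_n) and that Log1p applies log(1 + x) elementwise. The functions below are illustrative stand-ins, not the cellarium-ml implementations.

```python
import numpy as np


def normalize_total(x_ng, total_mrna_umis_n, target_count=10_000):
    # Assumed behavior: scale every cell (row) to the same total count.
    return target_count * x_ng / total_mrna_umis_n[:, None]


def log1p(x_ng):
    # Elementwise log(1 + x); keeps zeros at zero.
    return np.log1p(x_ng)


# Two toy cells with the same composition but a 10x difference in depth.
x_ng = np.array([[1.0, 3.0], [10.0, 30.0]])
total_mrna_umis_n = x_ng.sum(axis=1)  # here the totals equal the row sums

normalized_ng = normalize_total(x_ng, total_mrna_umis_n)
transformed_ng = log1p(normalized_ng)
```

After normalization both toy cells have identical rows, which is the point of the transform: differences in sequencing depth are removed before the statistics are computed.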
Change OnePassMeanVarStd’s algorithm to shifted_data:
< algorithm: naive
---
> algorithm: shifted_data
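The difference between the two options can be illustrated with a one-pass accumulator. A naive one-pass variance computes E[x²] − E[x]², which can lose precision when the mean is large relative to the spread; the shifted-data variant subtracts a constant shift from every value before accumulating. The sketch below is a generic NumPy version of that idea, assuming the shift is taken from the first batch and a population (not sample) variance. The class name is hypothetical and the actual cellarium-ml implementation may differ.

```python
import numpy as np


class ShiftedOnePassStats:
    """One-pass mean/std over batches using the shifted-data formulation."""

    def __init__(self):
        self.n = 0
        self.shift_g = None   # per-gene shift, fixed after the first batch
        self.sum_g = None     # running sum of (x - shift)
        self.sum_sq_g = None  # running sum of (x - shift) ** 2

    def update(self, x_ng):
        if self.shift_g is None:
            # Assumption: use the first batch's mean as the shift.
            self.shift_g = x_ng.mean(axis=0)
            self.sum_g = np.zeros_like(self.shift_g)
            self.sum_sq_g = np.zeros_like(self.shift_g)
        d_ng = x_ng - self.shift_g
        self.n += x_ng.shape[0]
        self.sum_g += d_ng.sum(axis=0)
        self.sum_sq_g += (d_ng ** 2).sum(axis=0)

    @property
    def mean_g(self):
        return self.shift_g + self.sum_g / self.n

    @property
    def std_g(self):
        # Population variance of the shifted data equals that of the data.
        var_g = self.sum_sq_g / self.n - (self.sum_g / self.n) ** 2
        return np.sqrt(var_g)


rng = np.random.default_rng(0)
x_ng = 1e6 + rng.normal(size=(400, 3))  # large mean, small spread
stats = ShiftedOnePassStats()
for batch in np.array_split(x_ng, 4):
    stats.update(batch)
```

Because the accumulated values are small after the shift, the result stays close to a two-pass computation even when the raw values are large.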
data
Configure the DistributedAnnDataCollection. Here we validate the obs columns that are used by the transforms and the model (total_mrna_umis). Each loaded AnnData file is validated against the first (reference) AnnData file by checking that the column names and dtypes match:
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
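The kind of check described above can be sketched with pandas, since obs is a DataFrame in AnnData. The function validate_obs_columns below is a hypothetical illustration of comparing column names and dtypes against a reference; it is not the cellarium-ml code.

```python
import pandas as pd


def validate_obs_columns(reference_obs, obs, columns):
    """Raise ValueError if a listed column is missing or its dtype differs."""
    for column in columns:
        if column not in reference_obs.columns or column not in obs.columns:
            raise ValueError(f"obs column {column!r} is missing")
        if obs[column].dtype != reference_obs[column].dtype:
            raise ValueError(
                f"obs column {column!r} has dtype {obs[column].dtype}, "
                f"expected {reference_obs[column].dtype}"
            )


# Toy obs tables standing in for the obs of two AnnData files.
reference_obs = pd.DataFrame({"total_mrna_umis": [1200.0, 950.0]})
matching_obs = pd.DataFrame({"total_mrna_umis": [800.0]})
mismatched_obs = pd.DataFrame({"total_mrna_umis": [800]})  # int, not float

validate_obs_columns(reference_obs, matching_obs, ["total_mrna_umis"])
```

Passing mismatched_obs instead would raise a ValueError, because its column is stored as integers while the reference uses floats.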
Configure the DataLoader. batch_keys must include all input arguments to the transforms and the model. For example, NormalizeTotal’s arguments are x_ng and total_mrna_umis_n, Log1p’s argument is x_ng, and OnePassMeanVarStd’s arguments are x_ng and var_names_g:
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
>   x_ng:
>     attr: X
>     convert_fn: cellarium.ml.utilities.data.densify
>   var_names_g:
>     attr: var_names
>   total_mrna_umis_n:
>     attr: obs
>     key: total_mrna_umis
> batch_size: 100
> num_workers: 2
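One way to read the batch_keys mapping is as a recipe: for each batch key, take an attribute of the AnnData object (attr), optionally select an entry from it (key), and optionally convert the result (convert_fn). The sketch below mimics that reading with a plain stand-in object instead of a real AnnData. Both build_batch and the densify stand-in are hypothetical and only illustrate the assumed semantics.

```python
from types import SimpleNamespace

import numpy as np


def densify(x):
    # Stand-in for a densify helper: sparse matrices expose toarray().
    return x.toarray() if hasattr(x, "toarray") else np.asarray(x)


BATCH_KEYS = {
    "x_ng": {"attr": "X", "convert_fn": densify},
    "var_names_g": {"attr": "var_names"},
    "total_mrna_umis_n": {"attr": "obs", "key": "total_mrna_umis"},
}


def build_batch(adata, batch_keys):
    """Assemble a dict of arrays from an AnnData-like object."""
    batch = {}
    for name, spec in batch_keys.items():
        value = getattr(adata, spec["attr"])
        if "key" in spec:
            value = value[spec["key"]]
        convert_fn = spec.get("convert_fn")
        if convert_fn is not None:
            value = convert_fn(value)
        batch[name] = np.asarray(value)
    return batch


# Minimal AnnData-like stand-in with 2 cells and 3 genes.
adata = SimpleNamespace(
    X=np.array([[1, 0, 2], [0, 3, 0]]),
    var_names=["gene_a", "gene_b", "gene_c"],
    obs={"total_mrna_umis": [3, 3]},
)
batch = build_batch(adata, BATCH_KEYS)
```

The resulting dictionary has exactly the keys that the transforms and the model expect as arguments, which is why a missing entry in batch_keys shows up as a missing argument at training time.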
2. IncrementalPCA
Generate a default config file:
cellarium-ml incremental_pca fit --print_config > ipca_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices and set the path for logs and weights:
< devices: auto
< default_root_dir: null
---
> devices: 2
> default_root_dir: /tmp/test_examples/ipca
model
Add NormalizeTotal, Log1p, and ZScore transforms. Note that the mean_g, std_g, and var_names_g arguments of the ZScore transform are loaded from the OnePassMeanVarStd checkpoint:
Note
cellarium-ml does not validate the transforms applied to the data. Always verify yourself that the transforms are configured correctly; if they are not, your model will silently produce wrong results. In the example below, we first apply the NormalizeTotal and Log1p transforms to the data and then apply the ZScore transform. Importantly, the mean_g and std_g parameters of the ZScore transform were calculated by the OnePassMeanVarStd model on data that was also transformed with NormalizeTotal and Log1p.
< transforms: null
---
> transforms:
> - class_path: cellarium.ml.transforms.NormalizeTotal
>   init_args:
>     target_count: 10_000
> - cellarium.ml.transforms.Log1p
> - class_path: cellarium.ml.transforms.ZScore
>   init_args:
>     mean_g:
>       !CheckpointLoader
>       file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>       attr: model.mean_g
>       convert_fn: null
>     std_g:
>       !CheckpointLoader
>       file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>       attr: model.std_g
>       convert_fn: null
>     var_names_g:
>       !CheckpointLoader
>       file_path: /tmp/test_examples/onepass/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>       attr: model.var_names_g
>       convert_fn: numpy.ndarray.tolist
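The importance of matching transforms can be checked numerically. The sketch below assumes that ZScore computes (x − mean) / (std + eps) per gene; the eps term and the function z_score are assumptions made for illustration, not the cellarium-ml code. It compares statistics computed on the transformed data with statistics mistakenly computed on raw counts.

```python
import numpy as np


def z_score(x_ng, mean_g, std_g, eps=1e-6):
    # Assumed form of the transform; eps guards against zero std.
    return (x_ng - mean_g) / (std_g + eps)


rng = np.random.default_rng(0)
counts_ng = rng.poisson(5.0, size=(200, 30)).astype(float)
totals_n = counts_ng.sum(axis=1)
# Same preprocessing as in the config: normalize to 10_000, then log1p.
transformed_ng = np.log1p(10_000 * counts_ng / totals_n[:, None])

# Correct: statistics computed on the transformed data.
z_good = z_score(transformed_ng, transformed_ng.mean(axis=0), transformed_ng.std(axis=0))

# Misconfigured: statistics computed on raw counts, applied to transformed data.
z_bad = z_score(transformed_ng, counts_ng.mean(axis=0), counts_ng.std(axis=0))
```

With matching statistics every gene ends up with approximately zero mean and unit standard deviation, whereas the mismatched version runs without any error and yields values that are no longer standardized.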
Set the number of components for IncrementalPCA:
< n_components: null
---
> n_components: 50
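For intuition about how PCA can be fitted batch by batch, here is a simplified incremental SVD update in NumPy. It assumes the data is already centered (as it roughly is after z-scoring) and keeps only the top n_components singular values and directions between batches. This is a generic sketch and not necessarily the algorithm implemented by IncrementalPCA in cellarium-ml; the function name is hypothetical.

```python
import numpy as np


def incremental_pca_update(state, x_ng, n_components):
    """Merge a new batch into the running (singular values, components)."""
    if state is None:
        stacked = x_ng
    else:
        singular_values_k, components_kg = state
        # The scaled components summarize all batches seen so far.
        stacked = np.vstack([singular_values_k[:, None] * components_kg, x_ng])
    _, s, vt = np.linalg.svd(stacked, full_matrices=False)
    return s[:n_components], vt[:n_components]


rng = np.random.default_rng(0)
# Rank-5 toy data, so 5 components capture it without truncation error.
x_ng = rng.normal(size=(400, 5)) @ rng.normal(size=(5, 20))

state = None
for batch in np.array_split(x_ng, 4):
    state = incremental_pca_update(state, batch, n_components=5)

singular_values_k, components_kg = state
embedding_nk = x_ng @ components_kg.T  # principal component embeddings
```

In this toy setting the singular values match those of a full SVD of the data. When n_components is smaller than the rank of the data, the truncation makes the result an approximation.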
data
Configure the DistributedAnnDataCollection. Here we validate the obs columns that are used by the transforms and the model (total_mrna_umis):
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
Configure the DataLoader. batch_keys contains the same keys as for OnePassMeanVarStd above:
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
>   x_ng:
>     attr: X
>     convert_fn: cellarium.ml.utilities.data.densify
>   var_names_g:
>     attr: var_names
>   total_mrna_umis_n:
>     attr: obs
>     key: total_mrna_umis
> batch_size: 100
> num_workers: 2
3. LogisticRegression
Generate a default config file:
cellarium-ml logistic_regression fit --print_config > lr_train_config.yaml
Below we highlight the changes made to the default configuration file.
trainer
Change the number of devices, checkpoint the model after every training step, set the number of epochs, log after every training step, and set the path for logs and weights:
< devices: auto
< callbacks: null
< max_epochs: null
< log_every_n_steps: null
< default_root_dir: null
---
> devices: 2
> callbacks:
> - class_path: lightning.pytorch.callbacks.ModelCheckpoint
>   init_args:
>     every_n_train_steps: 1
>     save_top_k: -1
> max_epochs: 5
> log_every_n_steps: 1
> default_root_dir: /tmp/test_examples/lr
model
Add the trained PCA model as a transform. Note that the trained PCA model contains the NormalizeTotal, Log1p, and ZScore transforms in its pipeline:
< transforms: null
---
> transforms:
> - !CheckpointLoader
>   file_path: /tmp/test_examples/ipca/lightning_logs/version_0/checkpoints/epoch=0-step=2.ckpt
>   attr: null
>   convert_fn: null
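The reason a single transform entry is enough can be sketched as follows: a fitted model that carries its own upstream transforms can be called like any other transform, and the classifier then only sees the resulting embedding. The wrapper class below is hypothetical and uses toy stand-ins for the upstream transforms and the fitted components; it is not how cellarium-ml represents checkpoints.

```python
import numpy as np


class FittedModelAsTransform:
    """Hypothetical wrapper: a fitted PCA that applies its own pipeline first."""

    def __init__(self, upstream_transforms, components_kg):
        self.upstream_transforms = upstream_transforms
        self.components_kg = components_kg

    def __call__(self, x_ng):
        for transform in self.upstream_transforms:
            x_ng = transform(x_ng)
        # Project the preprocessed data onto the principal components.
        return x_ng @ self.components_kg.T


rng = np.random.default_rng(0)
x_ng = rng.poisson(5.0, size=(8, 6)).astype(float)

# Toy stand-ins for the fitted state: a fixed per-gene mean and
# 2 orthonormal directions in the 6-dimensional gene space.
mean_g = np.log1p(x_ng).mean(axis=0)
components_kg = np.linalg.qr(rng.normal(size=(6, 2)))[0].T

pca_transform = FittedModelAsTransform(
    upstream_transforms=[np.log1p, lambda x: x - mean_g],
    components_kg=components_kg,
)

# The classifier's transform list needs only this one entry.
classifier_transforms = [pca_transform]
embedding_nk = classifier_transforms[0](x_ng)
```

Repeating the upstream transforms in the classifier config in addition to the PCA entry would apply them twice, which is one of the silent misconfigurations to watch for.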
Set the optimizer and its learning rate:
< optim_fn: null
< optim_kwargs: null
---
> optim_fn: torch.optim.Adam
> optim_kwargs:
>   lr: 0.1
data
Configure the DistributedAnnDataCollection. Here we validate the obs columns that are used by the transforms and the model (total_mrna_umis as above, plus the assay column):
< filenames: null
< shard_size: null
< max_cache_size: 1
< obs_columns_to_validate: null
---
> filenames: https://storage.googleapis.com/dsp-cellarium-cas-public/test-data/test_{0..3}.h5ad
> shard_size: 100
> max_cache_size: 2
> obs_columns_to_validate:
> - total_mrna_umis
> - assay
Configure the DataLoader. batch_keys contains the same keys as above and additionally y_n, which is an argument to LogisticRegression:
< batch_keys: null
< batch_size: 1
< num_workers: 0
---
> batch_keys:
>   x_ng:
>     attr: X
>     convert_fn: cellarium.ml.utilities.data.densify
>   var_names_g:
>     attr: var_names
>   total_mrna_umis_n:
>     attr: obs
>     key: total_mrna_umis
>   y_n:
>     attr: obs
>     key: assay
>     convert_fn: cellarium.ml.utilities.data.categories_to_codes
> batch_size: 25
> num_workers: 2
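The classifier needs integer class labels, so the categorical assay column has to be turned into codes. The sketch below shows the idea with pandas, assuming that the conversion amounts to taking the integer codes of a categorical column. The function name and the assay values are made up for illustration.

```python
import pandas as pd


def categories_to_codes_sketch(column):
    # Each category is replaced by its position in the category list.
    return column.cat.codes.to_numpy()


# Made-up assay labels for four cells.
assay = pd.Series(["assay_a", "assay_b", "assay_a", "assay_c"], dtype="category")
y_n = categories_to_codes_sketch(assay)
```

The codes depend on the order of the categories, so the same label must map to the same code in every file. This is one reason the assay column is included in obs_columns_to_validate, since the dtype of a categorical column includes its categories.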
4. Resume training
To resume training of the logistic regression model from a saved checkpoint, add the checkpoint file path to the config file:
< ckpt_path: null
---
> ckpt_path: /tmp/test_examples/lr/lightning_logs/version_0/checkpoints/epoch=1-step=13.ckpt
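The checkpoint file names follow the epoch=...-step=... pattern, and the numbers can be reproduced with a little arithmetic. The sketch below assumes that each of the four files holds exactly shard_size = 100 cells (400 cells in total), that every device processes one batch per step, and that the step in the file name is the number of training steps completed so far. The helper names are hypothetical.

```python
import math


def steps_per_epoch(n_cells, batch_size, devices):
    # Each step consumes one batch on every device.
    return math.ceil(n_cells / (batch_size * devices))


def epoch_of_step(global_step, n_steps_per_epoch):
    # Zero-based epoch in which the given (one-based) step was completed.
    return (global_step - 1) // n_steps_per_epoch


n_cells = 4 * 100  # assumption: 4 files with shard_size = 100 cells each

# OnePassMeanVarStd and IncrementalPCA: batch_size 100 on 2 devices.
stats_steps = steps_per_epoch(n_cells, batch_size=100, devices=2)

# LogisticRegression: batch_size 25 on 2 devices.
lr_steps = steps_per_epoch(n_cells, batch_size=25, devices=2)
resume_epoch = epoch_of_step(13, lr_steps)
```

Under these assumptions the first two models finish their single epoch after 2 steps, matching epoch=0-step=2.ckpt, and step 13 of the logistic regression run falls in the second epoch (index 1), matching epoch=1-step=13.ckpt.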