{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Part I: Preprocessing\n", "\n", "In this tutorial, we will demonstrate how `cellarium-ml` can be used to compute mean and variance, identify highly variable genes, and perform dimensionality reduction with PCA." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import lightning.pytorch as pl\n", "import numpy as np\n", "import pandas as pd\n", "import scanpy as sc\n", "\n", "from cellarium.ml import CellariumAnnDataDataModule, CellariumModule\n", "from cellarium.ml.data import read_h5ad_file\n", "from cellarium.ml.models import IncrementalPCA, OnePassMeanVarStd\n", "from cellarium.ml.preprocessing import get_highly_variable_genes\n", "from cellarium.ml.transforms import Filter, Log1p, NormalizeTotal, ZScore\n", "from cellarium.ml.utilities.core import resolve_ckpt_dir\n", "from cellarium.ml.utilities.data import AnnDataField, collate_fn, densify\n", "\n", "sc.settings.set_figure_params(dpi=80, facecolor=\"white\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download dataset\n", "\n", "As an example dataset, we will use [*A (Balanced) Bone Marrow Reference Map of Hematopoietic Development*](https://cellxgene.cziscience.com/collections/f6c50495-3361-40ed-a819-fb9644396ed9), freely available from CELLxGENE." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "adata = read_h5ad_file(\"https://datasets.cellxgene.cziscience.com/8674c375-ae3a-433c-97de-3c56cf8f7304.h5ad\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compute the mean and variance\n", "\n", "We will calculate the per-gene mean and variance of the normalized and log1p-transformed counts. 
The computation is performed in a single pass by iterating over mini-batches (of size 10,000 cells) and using the [shifted data algorithm](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) to compute the mean and variance. We will use `pl.Trainer` to orchestrate the training, which under the hood does the following (pseudocode):\n", "\n", "```py\n", "# instantiate the module\n", "onepass_module = CellariumModule(transforms=..., model=...)\n", "# configure_model creates a CellariumPipeline consisting of the transforms and the model\n", "onepass_module.configure_model()\n", "\n", "# instantiate the datamodule\n", "datamodule = CellariumAnnDataDataModule(dadc=..., batch_keys=..., batch_size=..., ...)\n", "# set up the iterable dataset\n", "datamodule.setup(stage=\"fit\")\n", "\n", "# training loop\n", "# batch dictionary has the same keys as the batch_keys above\n", "for batch_idx, batch in enumerate(datamodule.train_dataloader()):\n", "    # batch is passed through the transforms and the model in the pipeline\n", "    onepass_module.pipeline(batch)\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's instantiate the `CellariumModule` consisting of transforms (`NormalizeTotal` and `Log1p`) and the model (`OnePassMeanVarStd`):" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "onepass_module = CellariumModule(\n", "    transforms=[\n", "        NormalizeTotal(target_count=10_000),\n", "        Log1p(),\n", "    ],\n", "    model=OnePassMeanVarStd(var_names_g=adata.var_names, algorithm=\"shifted_data\"),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once configured by the `pl.Trainer`, `onepass_module.pipeline` will be a `CellariumPipeline` consisting of a list of `NormalizeTotal`, `Log1p`, and `OnePassMeanVarStd` sub-modules. Note that `CellariumPipeline` is a sub-class of `nn.ModuleList`, so its sub-modules can be accessed via indexing or slicing. 
Its forward method accepts a batch dictionary from the dataloader as input and then runs its sub-modules sequentially, passing each one the arguments it needs from the batch dictionary. The pseudocode looks like this:\n", "\n", "```py\n", "# CellariumPipeline forward method\n", "def forward(self, batch: dict) -> dict:\n", "    # the first sub-module is NormalizeTotal\n", "    out = self[0](x_ng=batch[\"x_ng\"], total_mrna_umis_n=batch[\"total_mrna_umis_n\"])\n", "    # overwrite `x_ng` key in the batch with the normalized counts\n", "    batch.update(out)\n", "\n", "    # the second sub-module is Log1p\n", "    out = self[1](x_ng=batch[\"x_ng\"])\n", "    # overwrite `x_ng` key in the batch with the log1p transformed counts\n", "    batch.update(out)\n", "\n", "    # the third sub-module is OnePassMeanVarStd which processes x_ng and updates its internal state\n", "    out = self[2](x_ng=batch[\"x_ng\"], var_names_g=batch[\"var_names_g\"])\n", "    # this is a no-op because OnePassMeanVarStd returns an empty dictionary\n", "    batch.update(out)\n", "\n", "    return batch\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will create a `CellariumAnnDataDataModule`, which the trainer uses to create the dataloader during training. The `dadc` argument specifies the `AnnData` object to iterate over. Note that when working with multiple anndata files, a `DistributedAnnDataCollection` object can be used instead. The `batch_keys` dictionary specifies the content of the generated batches. Its keys must match the argument names of the transforms and the model, as we have seen above. In this case it should contain the `x_ng`, `total_mrna_umis_n`, and `var_names_g` keys. Its values are `AnnDataField` objects that the datamodule uses to retrieve batch arguments from the anndata object. For example, `x_ng` is obtained by accessing the `raw.X` attribute of an anndata object and then densifying it using the `cellarium.ml.utilities.data.densify` function. 
Next, the `batch_size` can be set to a high number since the `OnePassMeanVarStd` model is very fast and is not memory intensive. Finally, `num_workers` should also be set to a high number to ensure that the dataloader doesn't become a bottleneck during training." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "datamodule = CellariumAnnDataDataModule(\n", "    dadc=adata,\n", "    batch_keys={\n", "        \"x_ng\": AnnDataField(attr=\"raw.X\", convert_fn=densify),\n", "        \"var_names_g\": AnnDataField(attr=\"var_names\"),\n", "        \"total_mrna_umis_n\": AnnDataField(attr=\"obs\", key=\"nCount_RNA\"),\n", "    },\n", "    batch_size=10_000,\n", "    num_workers=8,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, initialize `pl.Trainer`, configuring it to use one `gpu` accelerator, to iterate over the entire dataset once, and to save the checkpoint under the given root directory. Lastly, run the computation by invoking the `fit` method of the trainer." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/yordabay/anaconda3/envs/cellarium/lib/python3.10/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/yordabay/anaconda3/envs/cellarium/lib/python3. 
...\n", "GPU available: True (cuda), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n", "/home/yordabay/anaconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:182: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer\n", "\n", " | Name | Type | Params | Mode \n", "-------------------------------------------------------\n", "0 | pipeline | CellariumPipeline | 1 | train\n", "-------------------------------------------------------\n", "1 Trainable params\n", "0 Non-trainable params\n", "1 Total params\n", "0.000 Total estimated model params size (MB)\n", "4 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2fbbc346ebcc4016bad58355b95c9900", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Sanity Checking: | | 0/? [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "/mnt/disks/dev/repos/cellarium-ml/cellarium/ml/utilities/distributed.py:52: UserWarning: Distributed package is available but the default process group has not been initialized. Falling back to ``rank=0`` and ``num_replicas=1``.\n", " warnings.warn(\n", "/home/yordabay/anaconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:105: Total length of `DataLoader` across ranks is zero. Please make sure this was your intention.\n", "/home/yordabay/anaconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/utilities/data.py:122: Your `IterableDataset` has `__len__` defined. 
In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.\n", "/home/yordabay/anaconda3/envs/cellarium/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (27) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "29155c2b51164b22b5361efa07442f3a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Training: | | 0/? [00:00, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=1` reached.\n" ] } ], "source": [ "trainer = pl.Trainer(\n", " accelerator=\"gpu\",\n", " devices=1,\n", " max_epochs=1,\n", " default_root_dir=\"runs/onepass\",\n", ")\n", "trainer.fit(onepass_module, datamodule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize results" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Convert the computed statistics into a dataframe and make a scatter-plot of the mean vs. variance." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
 \n", "| | mean | variance | n_samples |\n",
"|---|---|---|---|\n",
"| ENSG00000177757 | 0.000474 | 0.000702 | 263159.0 |\n",
"| ENSG00000225880 | 0.018753 | 0.027754 | 263159.0 |\n",
"| ENSG00000187634 | 0.001644 | 0.002092 | 263159.0 |\n",
"| ENSG00000188976 | 0.267763 | 0.335251 | 263159.0 |\n",
"| ENSG00000187961 | 0.005226 | 0.008233 | 263159.0 |\n", "