{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set up pretrained Geneformer model and make predictions\n",
    "\n",
    "This notebook downloads the Geneformer token and gene-median dictionaries from the Hugging Face Hub, wraps the pretrained model in a `CellariumPipeline`, and runs predictions on synthetic count data, including a feature-deletion and a feature-masking example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pickle\n",
    "import tempfile\n",
    "import urllib.request\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from transformers import AutoModel\n",
    "\n",
    "from cellarium.ml.core import CellariumPipeline\n",
    "from cellarium.ml.models import Geneformer\n",
    "from cellarium.ml.transforms import DivideByScale, NormalizeTotal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "GENEFORMER_HUB_URL = \"https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer\"\n",
    "# Entries of the token dictionary that are special tokens rather than genes.\n",
    "# NOTE(review): assumed to be the two non-gene keys of the published dictionary -- confirm.\n",
    "SPECIAL_TOKENS = (\"<pad>\", \"<mask>\")\n",
    "\n",
    "\n",
    "def _download_pickle(filename: str, tmpdir: str):\n",
    "    \"\"\"Download ``filename`` from ``GENEFORMER_HUB_URL`` into ``tmpdir`` and return the unpickled object.\"\"\"\n",
    "    local_path = os.path.join(tmpdir, filename)\n",
    "    # urlretrieve raises on network/HTTP errors instead of silently leaving a missing or partial file\n",
    "    urllib.request.urlretrieve(f\"{GENEFORMER_HUB_URL}/{filename}\", local_path)\n",
    "    with open(local_path, \"rb\") as f:\n",
    "        # NOTE: unpickling can execute arbitrary code; only load files from a source you trust\n",
    "        return pickle.load(f)\n",
    "\n",
    "\n",
    "def get_pretrained_geneformer_pipeline(device) -> CellariumPipeline:\n",
    "    \"\"\"\n",
    "    Build a prediction pipeline around the pretrained Geneformer model.\n",
    "\n",
    "    Args:\n",
    "        device: Device (e.g. ``\"cuda\"`` or ``\"cpu\"``) that the gene medians and the model are moved to.\n",
    "\n",
    "    Returns:\n",
    "        A ``CellariumPipeline`` made of ``NormalizeTotal``, ``DivideByScale`` and the ``Geneformer``\n",
    "        model (in eval mode) whose ``bert`` attribute is the pretrained hub model.\n",
    "\n",
    "    Raises:\n",
    "        urllib.error.URLError: If a dictionary file cannot be downloaded.\n",
    "        KeyError: If a gene in the token dictionary has no entry in the gene median dictionary.\n",
    "    \"\"\"\n",
    "    with tempfile.TemporaryDirectory() as tmpdir:\n",
    "        token_dict = _download_pickle(\"token_dictionary.pkl\", tmpdir)\n",
    "        gene_median_dict = _download_pickle(\"gene_median_dictionary.pkl\", tmpdir)\n",
    "\n",
    "    # obtain var_names_g list from the token dict, leaving out the special (non-gene) tokens\n",
    "    gene_names = [name for name in token_dict if name not in SPECIAL_TOKENS]\n",
    "    var_names_g = np.array(gene_names)\n",
    "\n",
    "    # obtain non-zero median gene counts, looked up by gene name so that they are\n",
    "    # guaranteed to line up with var_names_g regardless of the dictionaries' key order\n",
    "    gene_median_g = torch.as_tensor([gene_median_dict[name] for name in gene_names]).to(device)\n",
    "\n",
    "    # load the pre-trained model from the hub\n",
    "    pretrained_model = AutoModel.from_pretrained(\"ctheodoris/Geneformer\")\n",
    "\n",
    "    # construct the Geneformer model\n",
    "    geneformer = Geneformer(var_names_g=var_names_g)\n",
    "\n",
    "    # insert the trained model params\n",
    "    geneformer.bert = pretrained_model\n",
    "    geneformer.to(device)\n",
    "    geneformer.eval()\n",
    "\n",
    "    # construct the pipeline\n",
    "    pipeline = CellariumPipeline(\n",
    "        [\n",
    "            NormalizeTotal(target_count=10_000, eps=0),\n",
    "            DivideByScale(scale_g=gene_median_g, var_names_g=var_names_g, eps=0),\n",
    "            geneformer,\n",
    "        ]\n",
    "    )\n",
    "\n",
    "    return pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "pipeline = get_pretrained_geneformer_pipeline(device=device)\n",
    "pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# n_genes in trained model\n",
    "var_names_g = pipeline[-1].var_names_g  # pipeline[-1] is the Geneformer model (after normalization steps)\n",
    "n_genes = var_names_g.shape[0]\n",
    "n_genes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fake some data\n",
    "torch.manual_seed(0)  # make the synthetic counts reproducible across runs\n",
    "n = 4\n",
    "\n",
    "x_ng = (\n",
    "    torch.distributions.poisson.Poisson(torch.distributions.dirichlet.Dirichlet(torch.tensor([0.01])).sample([n_genes]))\n",
    "    .sample([n])\n",
    "    .squeeze(-1)  # drop only the trailing singleton axis so the batch axis survives when n == 1\n",
    "    .to(device)\n",
    ")\n",
    "\n",
    "x_ng.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# normal prediction\n",
    "batch = {\"x_ng\": x_ng, \"var_names_g\": var_names_g}\n",
    "pipeline.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# delete a feature (expression to zero)\n",
    "batch = {\"x_ng\": x_ng, \"var_names_g\": var_names_g, \"feature_deletion\": [\"ENSG00000000005\"]}\n",
    "pipeline.predict(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# mask gene ENSG00000000005\n",
    "batch = {\"x_ng\": x_ng, \"var_names_g\": var_names_g, \"feature_map\": {\"ENSG00000000005\": 1}}\n",
    "pipeline.predict(batch)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cellarium",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}