Set up a pretrained Geneformer model and make predictions

[ ]:
import os
import pickle
import tempfile

import numpy as np
import torch
from transformers import AutoModel

from cellarium.ml.core import CellariumPipeline
from cellarium.ml.models import Geneformer
from cellarium.ml.transforms import DivideByScale, NormalizeTotal
[ ]:
def get_pretrained_geneformer_pipeline(device) -> CellariumPipeline:
    with tempfile.TemporaryDirectory() as tmpdir:
        # download the token dictionary (Ensembl gene ID -> token ID) used by Geneformer
        os.system(
            f"wget -O {os.path.join(tmpdir, 'token_dictionary.pkl')} https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/token_dictionary.pkl"
        )
        with open(os.path.join(tmpdir, "token_dictionary.pkl"), "rb") as f:
            token_dict = pickle.load(f)
        # download the per-gene nonzero median expression values used for scaling
        os.system(
            f"wget -O {os.path.join(tmpdir, 'gene_median_dictionary.pkl')} https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer/gene_median_dictionary.pkl"
        )
        with open(os.path.join(tmpdir, "gene_median_dictionary.pkl"), "rb") as f:
            gene_median_dict = pickle.load(f)

    # obtain the ordered gene list (var_names_g) from the token dict, dropping the special tokens
    token_dict.pop("<pad>")
    token_dict.pop("<mask>")
    var_names_g = np.array(list(token_dict.keys()))

    # non-zero median gene counts, used as the per-gene scale in DivideByScale
    gene_median_g = torch.as_tensor(list(gene_median_dict.values())).to(device)

    # load the pre-trained model from the hub
    pretrained_model = AutoModel.from_pretrained("ctheodoris/Geneformer")

    # construct the Geneformer model
    geneformer = Geneformer(var_names_g=var_names_g)

    # insert the trained model params
    geneformer.bert = pretrained_model
    geneformer.to(device)
    geneformer.eval()

    # construct the pipeline
    pipeline = CellariumPipeline(
        [
            NormalizeTotal(target_count=10_000, eps=0),
            DivideByScale(scale_g=gene_median_g, var_names_g=var_names_g, eps=0),
            geneformer,
        ]
    )

    return pipeline
[ ]:
device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = get_pretrained_geneformer_pipeline(device=device)
pipeline
[ ]:
# number of genes in the pretrained model's vocabulary
var_names_g = pipeline[-1].var_names_g  # pipeline[-1] is the Geneformer model (after normalization steps)
n_genes = var_names_g.shape[0]
n_genes
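The perturbation examples below refer to a gene by its Ensembl ID, so it is worth confirming that the ID is actually in the model's vocabulary.

[ ]:
# sanity check: the gene perturbed in the examples below is in the model's vocabulary
"ENSG00000000005" in set(var_names_g)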
[ ]:
# simulate a small batch of random count data with the right number of genes
n = 4

x_ng = (
    torch.distributions.poisson.Poisson(torch.distributions.dirichlet.Dirichlet(torch.tensor([0.01])).sample([n_genes]))
    .sample([n])
    .squeeze()
    .to("cuda" if torch.cuda.is_available() else "cpu")
)

x_ng.shape
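In practice x_ng would come from real data rather than simulation. Below is a minimal sketch of building the same batch from an AnnData object; the in-memory adata constructed here, and the assumption that its var_names are Ensembl IDs in the model's gene order, are purely illustrative.

[ ]:
# hypothetical sketch: building the predict() batch from an AnnData object
# (assumes anndata, pandas, and scipy are installed, and that adata.var_names are
#  Ensembl gene IDs in the same order as the model's var_names_g)
import anndata
import pandas as pd
import scipy.sparse as sp

adata = anndata.AnnData(X=sp.csr_matrix(x_ng.cpu().numpy()), var=pd.DataFrame(index=var_names_g))

counts = adata.X.toarray() if sp.issparse(adata.X) else adata.X
x_ng_from_adata = torch.as_tensor(counts, dtype=torch.float32, device=device)
var_names_from_adata = adata.var_names.to_numpy()
x_ng_from_adata.shape, var_names_from_adata.shape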
[ ]:
# baseline prediction with no perturbation
batch = {"x_ng": x_ng, "var_names_g": var_names_g}
pipeline.predict(batch)
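The structure of the object returned by predict depends on the cellarium-ml version, so the "hidden_states" key used below is an assumption; inspect the output first and adjust. A minimal sketch of pooling a per-cell embedding from the final hidden layer:

[ ]:
out = pipeline.predict(batch)
print(list(out.keys()))  # inspect what this cellarium-ml version actually returns

# assumed key: an HF-style tuple of per-layer hidden states whose last entry is (cells, tokens, dims)
last_hidden_ncd = out["hidden_states"][-1]
cell_embedding_nd = last_hidden_ncd.mean(dim=1)  # mean-pool over the token dimension
cell_embedding_nd.shape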
[ ]:
# delete a feature (its expression is set to zero)
batch = {"x_ng": x_ng, "var_names_g": var_names_g, "feature_deletion": ["ENSG00000000005"]}
pipeline.predict(batch)
[ ]:
# mask gene ENSG00000000005 by mapping its token ID to 1 (the <mask> token)
batch = {"x_ng": x_ng, "var_names_g": var_names_g, "feature_map": {"ENSG00000000005": 1}}
pipeline.predict(batch)
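A common use of these perturbation hooks is to quantify how much masking or deleting a gene moves each cell's embedding. The sketch below reuses the same assumed "hidden_states" output key as above and compares the masked prediction to the baseline with cosine similarity.

[ ]:
# sketch: effect of masking ENSG00000000005 on the pooled cell embeddings
# (the "hidden_states" key is an assumption; adjust to your cellarium-ml version)
def pooled_embedding(batch: dict) -> torch.Tensor:
    with torch.no_grad():
        out = pipeline.predict(batch)
    return out["hidden_states"][-1].mean(dim=1)

baseline_nd = pooled_embedding({"x_ng": x_ng, "var_names_g": var_names_g})
masked_nd = pooled_embedding({"x_ng": x_ng, "var_names_g": var_names_g, "feature_map": {"ENSG00000000005": 1}})

# per-cell cosine similarity; values near 1 mean the perturbation barely moved the embedding
torch.nn.functional.cosine_similarity(baseline_nd, masked_nd, dim=-1)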