Set up pretrained Geneformer model and make predictions
[ ]:
import os
import pickle
import tempfile
import numpy as np
import torch
from transformers import AutoModel
from cellarium.ml.core import CellariumPipeline
from cellarium.ml.models import Geneformer
from cellarium.ml.transforms import DivideByScale, NormalizeTotal
[ ]:
def _download_and_unpickle(url: str, directory: str):
    """Download ``url`` into ``directory`` with ``wget`` and return the unpickled object."""
    import subprocess

    destination = os.path.join(directory, url.rsplit("/", 1)[-1])
    # An argument list (no shell) keeps the temp path from being re-parsed by a
    # shell, and check=True turns a failed download into CalledProcessError
    # instead of an empty file that only fails later inside pickle.load.
    subprocess.run(["wget", "-O", destination, url], check=True)
    with open(destination, "rb") as f:
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; only use this with a source you trust.
        return pickle.load(f)


def get_pretrained_geneformer_pipeline(device) -> CellariumPipeline:
    """Build a prediction pipeline around the pretrained ``ctheodoris/Geneformer`` weights.

    Downloads the Geneformer token dictionary and gene median dictionary from
    the Hugging Face hub, loads the pretrained model with
    ``AutoModel.from_pretrained`` and installs it as the ``bert`` attribute of
    a ``Geneformer`` model.

    Args:
        device: Device passed to ``.to(device)`` for both the gene median
            tensor and the Geneformer model.

    Returns:
        A ``CellariumPipeline`` made of ``NormalizeTotal(target_count=10_000, eps=0)``,
        ``DivideByScale`` (scaled by the gene medians) and the Geneformer model
        in eval mode.

    Raises:
        subprocess.CalledProcessError: If ``wget`` exits with a non-zero status.
        FileNotFoundError: If the ``wget`` executable is not available.
        KeyError: If the token dictionary lacks ``"<pad>"`` / ``"<mask>"`` or a
            gene in it has no entry in the gene median dictionary.
    """
    base_url = "https://huggingface.co/ctheodoris/Geneformer/resolve/main/geneformer"
    with tempfile.TemporaryDirectory() as tmpdir:
        token_dict = _download_and_unpickle(f"{base_url}/token_dictionary.pkl", tmpdir)
        gene_median_dict = _download_and_unpickle(f"{base_url}/gene_median_dictionary.pkl", tmpdir)

    # obtain var_names_g list from the token dict (special tokens are not genes)
    token_dict.pop("<pad>")
    token_dict.pop("<mask>")
    gene_names = list(token_dict)
    var_names_g = np.array(gene_names)

    # obtain non-zero median gene counts; look each one up by gene name so the
    # scale vector lines up with var_names_g even if the two pickles were not
    # written in the same key order
    gene_median_g = torch.as_tensor([gene_median_dict[gene] for gene in gene_names]).to(device)

    # load the pre-trained model from the hub
    pretrained_model = AutoModel.from_pretrained("ctheodoris/Geneformer")

    # construct the Geneformer model
    geneformer = Geneformer(var_names_g=var_names_g)

    # insert the trained model params
    geneformer.bert = pretrained_model
    geneformer.to(device)
    geneformer.eval()

    # construct the pipeline
    pipeline = CellariumPipeline(
        [
            NormalizeTotal(target_count=10_000, eps=0),
            DivideByScale(scale_g=gene_median_g, var_names_g=var_names_g, eps=0),
            geneformer,
        ]
    )
    return pipeline
[ ]:
# use the GPU when torch can see one, otherwise stay on the CPU
target_device = "cuda" if torch.cuda.is_available() else "cpu"
pipeline = get_pretrained_geneformer_pipeline(device=target_device)
pipeline
[ ]:
# number of genes in the trained model
# (the last pipeline step is the Geneformer model; the steps before it are normalization)
geneformer_step = pipeline[-1]
var_names_g = geneformer_step.var_names_g
n_genes = var_names_g.shape[0]
n_genes
[ ]:
# fake some data: n cells of Poisson counts, one column per gene
n = 4
# NOTE(review): the Dirichlet here has a single component, so every draw is
# (essentially) 1 and each gene gets a Poisson rate of ~1 — confirm this is
# the intended fake-data distribution.
rate_g1 = torch.distributions.dirichlet.Dirichlet(torch.tensor([0.01])).sample([n_genes])
x_ng = (
    torch.distributions.poisson.Poisson(rate_g1)
    .sample([n])
    # drop only the trailing singleton axis; a bare squeeze() would also
    # remove the cell axis when n == 1
    .squeeze(-1)
    .to("cuda" if torch.cuda.is_available() else "cpu")
)
x_ng.shape
[ ]:
# normal prediction: the batch holds only the counts and their gene names
batch = {
    "x_ng": x_ng,
    "var_names_g": var_names_g,
}
pipeline.predict(batch)
[ ]:
# delete a feature (expression to zero)
genes_to_delete = ["ENSG00000000005"]
batch = {
    "x_ng": x_ng,
    "var_names_g": var_names_g,
    "feature_deletion": genes_to_delete,
}
pipeline.predict(batch)
[ ]:
# mask gene ENSG00000000005
# NOTE(review): presumably the value 1 is the "<mask>" token id — confirm against the token dictionary
gene_to_token = {"ENSG00000000005": 1}
batch = {
    "x_ng": x_ng,
    "var_names_g": var_names_g,
    "feature_map": gene_to_token,
}
pipeline.predict(batch)