Source code for cellarium.ml.transforms.cellarium_gpt_tokenizer

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pandas as pd
import torch


class CellariumGPTTrainTokenizer(torch.nn.Module):
    """
    Tokenizer for the Cellarium GPT model.

    Args:
        context_len: Context length.
        gene_downsample_fraction: Fraction of genes to downsample.
        min_total_mrna_umis: Minimum total mRNA UMIs.
        max_total_mrna_umis: Maximum total mRNA UMIs.
        gene_vocab_sizes: Gene token vocabulary sizes.
        metadata_vocab_sizes: Metadata token vocabulary sizes.
        ontology_downsample_p: Probability parameter for ontology-based downsampling of metadata tokens.
        ontology_infos_path: Path to ontology information.
        prefix_len: Prefix length. If ``None``, the prefix length is sampled.
        metadata_prompt_token_list: List of metadata tokens to prompt. If ``None``, the metadata prompt
            tokens are sampled.
        obs_names_rng: If ``True``, cell IDs are used as random seeds for shuffling gene tokens.
            If ``False``, gene tokens are shuffled without a fixed random seed.
    """

    def __init__(
        self,
        context_len: int,
        gene_downsample_fraction: float,
        min_total_mrna_umis: int,
        max_total_mrna_umis: int,
        gene_vocab_sizes: dict[str, int],
        metadata_vocab_sizes: dict[str, int],
        ontology_downsample_p: float,
        ontology_infos_path: str,
        prefix_len: int | None = None,
        metadata_prompt_token_list: list[str] | None = None,
        obs_names_rng: bool = False,
    ) -> None:
        super().__init__()
        self.context_len = context_len
        self.gene_downsample_fraction = gene_downsample_fraction
        self.min_total_mrna_umis = min_total_mrna_umis
        self.max_total_mrna_umis = max_total_mrna_umis
        self.gene_vocab_sizes = gene_vocab_sizes
        self.metadata_vocab_sizes = metadata_vocab_sizes
        self.ontology_infos = torch.load(ontology_infos_path, weights_only=True)
        self.ontology_downsample_p = ontology_downsample_p
        self.prefix_len = prefix_len
        self.metadata_prompt_token_list = metadata_prompt_token_list
        self.obs_names_rng = obs_names_rng

    def forward(
        self,
        metadata_token_n_dict: dict[str, torch.Tensor],
        gene_token_n_dict: dict[str, torch.Tensor],
        gene_token_ng_dict: dict[str, torch.Tensor],
        obs_names_n: np.ndarray | None = None,
    ) -> dict[str, dict[str, torch.Tensor] | torch.Tensor]:
        ### GENE TOKENS ###
        n, g = gene_token_ng_dict["gene_value"].shape
        m = len(metadata_token_n_dict)
        c = self.context_len
        # gene context length
        j = c - m
        device = gene_token_ng_dict["gene_value"].device
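        # Dimension suffix convention used throughout: n = cells, g = genes,
        # m = metadata tokens, c = context positions, j = c - m gene positions.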
        ## gene measurement tokens (assay, suspension type, etc.) ##
        gene_token_nj_dict = {key: gene_token_n_dict[key][:, None].expand(-1, j).int() for key in gene_token_n_dict}

        ## gene id ##
        gene_token_ng_dict["gene_id"] = torch.arange(g, device=device).expand(n, g)
        if self.obs_names_rng:
            assert obs_names_n is not None
            rng_n = [torch.Generator(device=device) for _ in range(n)]
            [rng.manual_seed(int(obs_name)) for rng, obs_name in zip(rng_n, obs_names_n)]
            shuffle_idx_ng = torch.stack([torch.randperm(g, generator=rng, device=device) for rng in rng_n])
        else:
            shuffle_idx_ng = torch.argsort(torch.rand((n, g), dtype=torch.float32, device=device), dim=-1)
        shuffle_idx_nj = shuffle_idx_ng[:, :j]
        for key, gene_token_ng in gene_token_ng_dict.items():
            gene_token_nj_dict[key] = torch.gather(gene_token_ng, dim=-1, index=shuffle_idx_nj)

        ## gene value ##
        gene_value_nj = gene_token_nj_dict.pop("gene_value")
        total_mrna_umis_nj = gene_token_nj_dict.pop("total_mrna_umis")
        # downsample gene values
        max_total_mrna_umis = torch.tensor(self.max_total_mrna_umis, device=device)
        downsampled_total_mrna_umis_nj = torch.minimum(total_mrna_umis_nj, max_total_mrna_umis).float()
        if self.gene_downsample_fraction > 0:
            gene_downsample_p_nj = torch.minimum(
                torch.rand((n, j), device=device) / self.gene_downsample_fraction,
                torch.tensor(1.0, device=device),
            )
            downsampled_total_mrna_umis_nj = torch.lerp(
                torch.full_like(gene_downsample_p_nj, self.min_total_mrna_umis),
                downsampled_total_mrna_umis_nj,
                gene_downsample_p_nj,
            )
        gene_downsample_p_nj = downsampled_total_mrna_umis_nj / total_mrna_umis_nj
        gene_value_nj = torch.binomial(gene_value_nj, gene_downsample_p_nj)
        total_mrna_umis_nj = torch.round(downsampled_total_mrna_umis_nj)

        if self.prefix_len is not None:
            prefix_len_n = torch.full((n,), self.prefix_len, dtype=torch.float32)
        else:
            # sample prefix length
            # prefix_len_weights = [1 / 10, 1, 1 / 2, 1 / 3, ..., 1 / max_prefix_len]
            max_prefix_len = j - 1
            prefix_len_weights = 1 / torch.arange(max_prefix_len + 1, dtype=torch.float32)
            prefix_len_weights[0] = 1 / 10
            prefix_len_n = torch.multinomial(prefix_len_weights, n, replacement=True)
        # create prompt and query masks
        gene_query_mask_nj = torch.arange(j, device=device) >= prefix_len_n[:, None].expand(n, -1)
        gene_prompt_mask_nj = ~gene_query_mask_nj
        if "measured_genes_mask" in gene_token_nj_dict:
            measured_genes_mask_nj = gene_token_nj_dict.pop("measured_genes_mask")
            gene_query_mask_nj = gene_query_mask_nj & measured_genes_mask_nj
            gene_prompt_mask_nj = gene_prompt_mask_nj & measured_genes_mask_nj
        gene_token_nj_dict["gene_value"] = torch.log1p(gene_value_nj) * gene_prompt_mask_nj.float()
        gene_token_nj_dict["gene_query_mask"] = gene_query_mask_nj.float()
        gene_token_nj_dict["total_mrna_umis"] = torch.log1p(total_mrna_umis_nj)

        gene_token_value_nc_dict = {
            key: torch.cat([gene_token_nj, torch.zeros((n, m), device=device, dtype=gene_token_nj.dtype)], dim=1)
            for key, gene_token_nj in gene_token_nj_dict.items()
        }
        gene_token_mask_nc = torch.cat(
            [torch.ones((n, j), dtype=torch.bool, device=device), torch.zeros((n, m), dtype=torch.bool, device=device)],
            dim=1,
        )
        gene_token_mask_nc_dict = {key: gene_token_mask_nc for key in gene_token_nj_dict}

        # gene label
        gene_value_vocab_size = self.gene_vocab_sizes["gene_value"]
        gene_label_nj = gene_value_nj.clamp(0, gene_value_vocab_size - 1).int()
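        # The gene half of the context is now assembled: prompt genes carry their
        # log1p-transformed (downsampled) counts, query genes are zeroed out and
        # flagged via gene_query_mask, and the last m context positions are left
        # empty for the metadata tokens filled in below.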
        ### METADATA TOKENS ###

        ## metadata tokens ##
        # assign token codes based on the ontology info
        # token values not in the ontology are treated as unmeasured and assigned a code value of -1
        for key, ontology_info in self.ontology_infos.items():
            assert self.metadata_vocab_sizes[key] == len(ontology_info["names"])
            metadata_token_n_dict[key] = torch.tensor(
                pd.Categorical(metadata_token_n_dict[key], categories=ontology_info["names"]).codes,
                dtype=torch.int,
            )

        # create metadata query and prompt masks
        if self.metadata_prompt_token_list is not None:
            metadata_prompt_mask_nm = torch.zeros((n, m), dtype=torch.bool, device=device)
            for metadata_token_idx, metadata_token in enumerate(metadata_token_n_dict):
                if metadata_token in self.metadata_prompt_token_list:
                    metadata_prompt_mask_nm[:, metadata_token_idx] = True
        else:
            metadata_prefix_len_n = torch.randint(0, m + 1, (n,), device=device)
            metadata_prefix_mask_nm = torch.arange(m, device=device) < metadata_prefix_len_n[:, None]
            shuffle_idx_nm = torch.argsort(torch.rand_like(metadata_prefix_mask_nm, dtype=torch.float32), dim=-1)
            metadata_prompt_mask_nm = torch.gather(metadata_prefix_mask_nm, dim=-1, index=shuffle_idx_nm)
        metadata_query_mask_nm = ~metadata_prompt_mask_nm
        metadata_measured_mask_nm = torch.stack(
            [metadata_token_n >= 0 for metadata_token_n in metadata_token_n_dict.values()], dim=1
        ).bool()
        metadata_query_mask_nm = metadata_query_mask_nm & metadata_measured_mask_nm
        metadata_prompt_mask_nm = metadata_prompt_mask_nm & metadata_measured_mask_nm
        # clamp unmeasured tokens to 0 in order to avoid error during embedding
        # the value of unmeasured tokens doesn't matter since they will be masked out by the attention mask
        for key, metadata_token_n in metadata_token_n_dict.items():
            metadata_token_n_dict[key] = metadata_token_n.clamp(0).int()

        # metadata labels
        metadata_label_n_dict = {key: metadata_token_n_dict[key].clone() for key in metadata_token_n_dict}

        if self.ontology_downsample_p != 0:
            # downsample metadata based on ontology
            for key, ontology_info in self.ontology_infos.items():
                if "shortest_distances_matrix" not in ontology_info:
                    continue
                metadata_token_n = metadata_token_n_dict[key]
                shortest_distances_matrix = ontology_info["shortest_distances_matrix"]
                ontology_weights = (
                    self.ontology_downsample_p * (1 - self.ontology_downsample_p) ** shortest_distances_matrix
                )
                metadata_token_n_dict[key] = (
                    torch.multinomial(ontology_weights[metadata_token_n], num_samples=1).squeeze(-1).int()
                )

        # impute mask token for unmeasured metadata
        # mask token is the last token in the vocabulary
        for i, (key, metadata_token_n) in enumerate(metadata_token_n_dict.items()):
            metadata_token_n_dict[key] = torch.where(
                metadata_query_mask_nm[:, i], self.metadata_vocab_sizes[key], metadata_token_n
            ).int()

        block_metadata_token_nm = torch.block_diag(
            *[metadata_token_n_dict[key].unsqueeze(-1) for key in metadata_token_n_dict],
        )
        metadata_token_value_nc_dict = {
            key: torch.cat(
                [torch.zeros((n, j), dtype=torch.int, device=device), block_metadata_token_nm[n * i : n * (i + 1)]],
                dim=1,
            )
            for i, key in enumerate(metadata_token_n_dict)
        }
        block_metadata_token_mask_nm = torch.block_diag(
            *[torch.ones((n, 1), dtype=torch.bool, device=device) for _ in metadata_token_n_dict],
        )
        metadata_token_mask_nc_dict = {
            key: torch.cat(
                [
                    torch.zeros((n, j), dtype=torch.bool, device=device),
                    block_metadata_token_mask_nm[n * i : n * (i + 1)],
                ],
                dim=1,
            )
            for i, key in enumerate(metadata_token_n_dict)
        }

        ### PROMPT MASK ###
        prompt_mask_nc = torch.cat([gene_prompt_mask_nj, metadata_prompt_mask_nm], dim=1)

        ### LABELS ###
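        # Labels are laid out block-diagonally over an (n * num_keys, c) stack, so
        # each key's (n, c) slice is nonzero only at its own context positions:
        # gene_value over the first j columns, each metadata key over its own column.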
        block_label_nc = torch.block_diag(
            gene_label_nj,
            *[metadata_label_n.unsqueeze(-1) for metadata_label_n in metadata_label_n_dict.values()],
        )
        label_nc_dict = {
            key: block_label_nc[n * i : n * (i + 1)]
            for i, key in enumerate(["gene_value"] + list(metadata_token_n_dict))
        }

        ### LABEL WEIGHTS ###
        block_label_weight_nc = (
            torch.block_diag(
                gene_query_mask_nj / torch.maximum(gene_query_mask_nj.sum(dim=-1, keepdim=True), torch.tensor(1.0)),
                *[metadata_query_mask_nm[:, i].unsqueeze(-1).float() for i in range(m)],
            )
            / n
        )
        label_weight_nc_dict = {
            key: block_label_weight_nc[n * i : n * (i + 1)]
            for i, key in enumerate(["gene_value"] + list(metadata_token_n_dict))
        }

        return {
            "token_value_nc_dict": gene_token_value_nc_dict | metadata_token_value_nc_dict,
            "token_mask_nc_dict": gene_token_mask_nc_dict | metadata_token_mask_nc_dict,
            "prompt_mask_nc": prompt_mask_nc,
            "label_nc_dict": label_nc_dict,
            "label_weight_nc_dict": label_weight_nc_dict,
        }
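
A minimal smoke test makes the listing above concrete. The sketch below is illustrative only: the metadata key "cell_type", the ontology names, and all tensor sizes are made-up toy values, and the ontology file is normally produced elsewhere in the Cellarium pipeline rather than hand-built like this.

# Hypothetical usage sketch; all names and sizes below are toy values.
import os
import tempfile

import numpy as np
import torch

from cellarium.ml.transforms.cellarium_gpt_tokenizer import CellariumGPTTrainTokenizer

# toy ontology info: one metadata field with three category names
ontology_infos = {"cell_type": {"names": ["T cell", "B cell", "NK cell"]}}
ontology_infos_path = os.path.join(tempfile.mkdtemp(), "ontology_infos.pt")
torch.save(ontology_infos, ontology_infos_path)

tokenizer = CellariumGPTTrainTokenizer(
    context_len=10,  # j = 10 - 1 = 9 gene positions + 1 metadata position
    gene_downsample_fraction=0.5,
    min_total_mrna_umis=100,
    max_total_mrna_umis=10_000,
    gene_vocab_sizes={"gene_value": 1000},
    metadata_vocab_sizes={"cell_type": 3},
    ontology_downsample_p=0.0,
    ontology_infos_path=ontology_infos_path,
)

n, g = 4, 16  # cells, genes
gene_value_ng = torch.poisson(torch.full((n, g), 5.0))
batch = tokenizer(
    # metadata values are raw category names; names outside the ontology
    # (here "unknown") are coded as unmeasured
    metadata_token_n_dict={"cell_type": np.array(["T cell", "B cell", "unknown", "NK cell"])},
    # per-cell tokens, broadcast across all gene positions
    gene_token_n_dict={"total_mrna_umis": gene_value_ng.sum(dim=-1)},
    # per-gene tokens; "gene_value" is required
    gene_token_ng_dict={"gene_value": gene_value_ng},
)
print(batch["prompt_mask_nc"].shape)   # torch.Size([4, 10])
print(sorted(batch["label_nc_dict"]))  # ['cell_type', 'gene_value']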