Source code for cellarium.ml.layers.embedding

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any

import torch
from torch import nn

from cellarium.ml.utilities.layers import create_initializer


class GeneExpressionEmbedding(nn.Module):
    """
    Gene embedding.

    Args:
        categorical_vocab_sizes:
            Categorical gene token vocabulary sizes.
        continuous_vocab_sizes:
            Continuous gene token vocabulary sizes.
        d_model:
            Dimensionality of the embeddings and hidden states.
        embeddings_initializer:
            Initializer for the embeddings.
    """

    def __init__(
        self,
        categorical_vocab_sizes: dict[str, int],
        continuous_vocab_sizes: dict[str, int],
        d_model: int,
        embeddings_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        self.E = nn.ModuleDict()
        self.E.update({key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()})
        self.E.update(
            {key: nn.Linear(vocab_size, d_model, bias=False) for key, vocab_size in continuous_vocab_sizes.items()}
        )
        self.embeddings_initializer = embeddings_initializer

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        for module in self.E.children():
            create_initializer(self.embeddings_initializer)(module.weight)
    def forward(self, gene_tokens_nc: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            gene_tokens_nc:
                Dictionary of gene token tensors of shape ``(n, c)``.

        Returns:
            The gene embedding tensor of shape ``(n, c, d)``.
        """
        return sum(self.E[key](gene_token_nc) for key, gene_token_nc in gene_tokens_nc.items())
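
A minimal usage sketch (not part of the module source): the token names, vocabulary sizes, and initializer configuration below are illustrative assumptions, and the exact config schema accepted by ``create_initializer`` depends on ``cellarium.ml.utilities.layers``. Note that categorical tokens are integer IDs of shape ``(n, c)``, while continuous tokens feed the bias-free ``nn.Linear`` and therefore carry a trailing feature dimension equal to their declared vocabulary size.

# Illustrative example only; names, sizes, and the initializer config are assumptions.
gene_embedding = GeneExpressionEmbedding(
    categorical_vocab_sizes={"gene_id": 20_000},
    continuous_vocab_sizes={"gene_value": 1},
    d_model=256,
    embeddings_initializer={"name": "normal_", "std": 0.02},  # hypothetical config
)
gene_tokens_nc = {
    "gene_id": torch.randint(0, 20_000, (4, 128)),  # (n, c) integer token IDs
    "gene_value": torch.rand(4, 128, 1),            # (n, c, 1) continuous values
}
embedding_ncd = gene_embedding(gene_tokens_nc)      # (4, 128, 256)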
class MetadataEmbedding(nn.Module):
    """
    Metadata embedding.

    Args:
        categorical_vocab_sizes:
            Categorical metadata token vocabulary sizes.
        d_model:
            Dimensionality of the embeddings and hidden states.
        embeddings_initializer:
            Initializer for the embeddings.
    """

    def __init__(
        self,
        categorical_vocab_sizes: dict[str, int],
        d_model: int,
        embeddings_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        self.E = nn.ModuleDict(
            {key: nn.Embedding(vocab_size, d_model) for key, vocab_size in categorical_vocab_sizes.items()}
        )
        self.embeddings_initializer = embeddings_initializer

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        for module in self.E.children():
            create_initializer(self.embeddings_initializer)(module.weight)
    def forward(self, metadata_tokens_n: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            metadata_tokens_n:
                Dictionary of metadata token tensors of shape ``(n,)``.

        Returns:
            The metadata embedding tensor of shape ``(n, m, d)``.
        """
        return torch.stack(
            [self.E[key](metadata_token_n) for key, metadata_token_n in metadata_tokens_n.items()],
            dim=1,
        )
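
A matching sketch for metadata tokens, again with assumed field names, vocabulary sizes, and initializer config: each metadata field contributes one embedded token per cell, and the ``m`` field embeddings are stacked along dimension 1.

# Illustrative example only; field names, sizes, and the initializer config are assumptions.
metadata_embedding = MetadataEmbedding(
    categorical_vocab_sizes={"cell_type": 600, "assay": 20},
    d_model=256,
    embeddings_initializer={"name": "normal_", "std": 0.02},  # hypothetical config
)
metadata_tokens_n = {
    "cell_type": torch.randint(0, 600, (4,)),  # (n,) integer token IDs
    "assay": torch.randint(0, 20, (4,)),       # (n,)
}
embedding_nmd = metadata_embedding(metadata_tokens_n)  # (4, 2, 256)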