# Source code for cellarium.ml.layers.head

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any

import torch
from torch import nn

from cellarium.ml.utilities.layers import create_initializer


class MultiHeadReadout(nn.Module):
    """
    Multi-head readout.

    Maps a shared hidden state to per-feature logits via one independent
    linear head per categorical feature.

    Args:
        categorical_vocab_sizes:
            Categorical token vocabulary sizes.
        d_model:
            Dimensionality of the embeddings and hidden states.
        use_bias:
            Whether to use bias in the linear transformations.
        output_logits_scale:
            Multiplier for the output logits.
        heads_initializer:
            Initializer for the output linear transformations.
    """

    def __init__(
        self,
        categorical_vocab_sizes: dict[str, int],
        d_model: int,
        use_bias: bool,
        output_logits_scale: float,
        heads_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        # One linear readout head per categorical feature, projecting the
        # shared d_model hidden state to that feature's vocabulary logits.
        self.W = nn.ModuleDict(
            {key: nn.Linear(d_model, vocab_size, use_bias) for key, vocab_size in categorical_vocab_sizes.items()}
        )
        self.output_logits_scale = output_logits_scale
        self.heads_initializer = heads_initializer
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize each head's weight with the configured initializer; zero the biases."""
        for module in self.W.children():
            create_initializer(self.heads_initializer)(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, hidden_state_ncd: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            hidden_state_ncd: Hidden state tensor of shape ``(n, c, d)``.

        Returns:
            Dictionary of output logits tensors of shape ``(n, c, vocab_size)``,
            one entry per categorical feature, each scaled by
            ``output_logits_scale``.
        """
        return {key: self.output_logits_scale * self.W[key](hidden_state_ncd) for key in self.W}