Source code for cellarium.ml.layers.normadd

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Callable

import torch
from torch import nn


[docs] class NormAdd(nn.Module): """ Pre-norm layer where the layer normalization is applied before the sublayer. Args: norm_shape: The shape of the layer normalization. dropout_p: Dropout probability. use_bias: Whether to use bias in the layer normalization. """ def __init__(self, norm_shape: int, dropout_p: float, use_bias: bool) -> None: super().__init__() self.dropout = nn.Dropout(dropout_p) self.ln = nn.LayerNorm(norm_shape, bias=use_bias) self._reset_parameters() def _reset_parameters(self) -> None: self.ln.reset_parameters()
[docs] def forward( self, hidden_state_ncd: torch.Tensor, sublayer: Callable[[torch.Tensor], torch.Tensor], ) -> torch.Tensor: """ Args: hidden_state_ncd: Hidden state tensor of shape ``(n, c, d)``. sublayer: Sublayer function. Returns: The output hidden state tensor of shape ``(n, c, d)``. """ return hidden_state_ncd + self.dropout(sublayer(self.ln(hidden_state_ncd))) # _ncd