Source code for cellarium.ml.layers.ffn

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Any

import torch
from torch import nn

from cellarium.ml.utilities.layers import create_initializer


class PositionWiseFFN(nn.Module):
    """
    The positionwise feed-forward network.

    Args:
        d_ffn:
            Dimensionality of the inner feed-forward layers.
        d_model:
            Dimensionality of the embeddings and hidden states.
        use_bias:
            Whether to use bias in the linear transformations.
        dense1_initializer:
            Initializer for the first dense layer.
        dense2_initializer:
            Initializer for the second dense layer.
    """

    def __init__(
        self,
        d_ffn: int,
        d_model: int,
        use_bias: bool,
        dense1_initializer: dict[str, Any],
        dense2_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        self.dense1 = nn.Linear(d_model, d_ffn, bias=use_bias)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(d_ffn, d_model, bias=use_bias)
        self.dense1_initializer = dense1_initializer
        self.dense2_initializer = dense2_initializer

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        # Apply the configured initializers to the weights; biases are zeroed.
        create_initializer(self.dense1_initializer)(self.dense1.weight)
        if self.dense1.bias is not None:
            nn.init.zeros_(self.dense1.bias)

        create_initializer(self.dense2_initializer)(self.dense2.weight)
        if self.dense2.bias is not None:
            nn.init.zeros_(self.dense2.bias)

    def forward(self, hidden_state_ncd: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_state_ncd:
                Hidden state tensor of shape ``(n, c, d)``.

        Returns:
            The output hidden state tensor of shape ``(n, c, d)``.
        """
        return self.dense2(self.activation(self.dense1(hidden_state_ncd)))  # _ncd
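A minimal usage sketch, not part of the module itself. The dict format passed for the initializers is an assumption here (a ``"name"`` key selecting a ``torch.nn.init`` function); consult ``cellarium.ml.utilities.layers.create_initializer`` for the exact schema it accepts. The dimensions are arbitrary illustration values.

# Hypothetical usage example; initializer dict format is assumed.
ffn = PositionWiseFFN(
    d_ffn=2048,
    d_model=512,
    use_bias=True,
    dense1_initializer={"name": "xavier_uniform_"},  # assumed schema
    dense2_initializer={"name": "xavier_uniform_"},  # assumed schema
)
x_ncd = torch.randn(4, 16, 512)  # (n, c, d) batch of hidden states
y_ncd = ffn(x_ncd)
assert y_ncd.shape == x_ncd.shape  # applied position-wise, so shape is preserved

Because the two linear layers act on the last dimension only, the network is applied independently at every position ``c``, which is why input and output shapes match.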