# Copyright Contributors to the Cellarium project.# SPDX-License-Identifier: BSD-3-ClausefromtypingimportAnyimporttorchfromtorchimportnnfromcellarium.ml.utilities.layersimportcreate_initializer
class PositionWiseFFN(nn.Module):
    """
    Position-wise feed-forward network computing ``dense2(GELU(dense1(x)))``.

    Both linear maps act on the last dimension of the input, so every position
    is transformed independently with the same weights. ``dense1`` maps
    ``d_model -> d_ffn`` and ``dense2`` maps ``d_ffn -> d_model``.

    Args:
        d_ffn:
            Width of the inner (hidden) feed-forward layer.
        d_model:
            Width of the input/output embeddings and hidden states.
        use_bias:
            Whether both linear layers include a bias term. When present, biases
            are initialized to zero.
        dense1_initializer:
            Initializer configuration for the weight of the first linear layer,
            passed to ``create_initializer``.
        dense2_initializer:
            Initializer configuration for the weight of the second linear layer,
            passed to ``create_initializer``.
    """

    def __init__(
        self,
        d_ffn: int,
        d_model: int,
        use_bias: bool,
        dense1_initializer: dict[str, Any],
        dense2_initializer: dict[str, Any],
    ) -> None:
        super().__init__()
        # Submodules are registered in assignment order; keep dense1 -> activation
        # -> dense2 so that parameters() / state_dict() ordering is stable.
        self.dense1 = nn.Linear(d_model, d_ffn, bias=use_bias)
        self.activation = nn.GELU()
        self.dense2 = nn.Linear(d_ffn, d_model, bias=use_bias)
        # Stored on the instance because _reset_parameters reads them from there.
        self.dense1_initializer = dense1_initializer
        self.dense2_initializer = dense2_initializer
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        # Initialize each layer fully (weight, then bias) before moving on to the
        # next one, so the order of initializer calls is dense1 then dense2.
        layer_and_config = (
            (self.dense1, self.dense1_initializer),
            (self.dense2, self.dense2_initializer),
        )
        for layer, init_config in layer_and_config:
            init_weight = create_initializer(init_config)
            init_weight(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, hidden_state_ncd: torch.Tensor) -> torch.Tensor:
        """
        Apply the feed-forward transformation along the last dimension.

        Args:
            hidden_state_ncd:
                Hidden state tensor of shape ``(n, c, d)``.

        Returns:
            The transformed hidden state tensor of shape ``(n, c, d)``.
        """
        expanded_ncf = self.dense1(hidden_state_ncd)  # last dim: d_model -> d_ffn
        activated_ncf = self.activation(expanded_ncf)
        return self.dense2(activated_ncf)  # last dim: d_ffn -> d_model