# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
from typing import Any, Type
import torch
[docs]
class DressedLayer(torch.nn.Module):
"""
Small block comprising a `~torch.nn.Module` with optional batch/layer normalization
and configurable activation and dropout.
Similar to
torch.nn.Sequential(
layer,
optional batch normalization,
optional layer normalization,
optional activation,
optional dropout,
)
but the `layer` can take multiple inputs.
Note that batch normalization and layer normalization are mutually exclusive options.
Args:
layer: single layer `torch.nn.Module`, such as an instance of `torch.nn.Linear`
use_batch_norm: whether to use batch normalization
use_layer_norm: whether to use layer normalization
activation_fn: the activation function to use
dropout_rate: dropout rate, can be zero
"""
def __init__(
self,
layer: torch.nn.Module,
use_batch_norm: bool = False,
batch_norm_kwargs: dict | None = None,
use_layer_norm: bool = False,
layer_norm_kwargs: dict | None = None,
activation_fn: Type[torch.nn.Module] | None = torch.nn.ReLU,
dropout_rate: float = 0,
):
if batch_norm_kwargs is None:
batch_norm_kwargs = {"momentum": 0.01, "eps": 0.001}
if layer_norm_kwargs is None:
layer_norm_kwargs = {"elementwise_affine": False}
assert not (use_batch_norm and use_layer_norm), "Cannot use both batch and layer normalization."
super().__init__()
out_features = getattr(layer, "out_features", None)
if out_features is None:
raise ValueError(f"attempted to use {layer} in DressedLayer, but it does not define out_features")
assert isinstance(out_features, int), "The layer must have an `out_features` attribute of type `int`."
batch_norm = torch.nn.BatchNorm1d(out_features, **batch_norm_kwargs) if use_batch_norm else None
layer_norm = torch.nn.LayerNorm(out_features, **layer_norm_kwargs) if use_layer_norm else None
activation = activation_fn() if (activation_fn is not None) else None
dropout = torch.nn.Dropout(p=dropout_rate) if dropout_rate > 0 else None
module_list = [batch_norm, layer_norm, activation, dropout]
self.layer = layer
self.dressing = torch.nn.Sequential(*[m for m in module_list if m is not None])
[docs]
def forward(self, *args, **kwargs) -> torch.Tensor:
"""
Computes the forward pass of the block.
"""
x = self.layer(*args, **kwargs)
return self.dressing(x)
[docs]
class FullyConnectedLinear(torch.nn.Module):
"""
Fully connected block of layers (can be empty).
Args:
in_features: The dimensionality of the input
out_features: The dimensionality of the output
n_hidden: A list of sizes of torch.nn.Linear hidden layers
dressing_init_kwargs: A dictionary of keyword arguments to pass ``DressedLayer``'s constructor
bias: True to include a bias in the final linear layer
"""
def __init__(
self,
in_features: int,
out_features: int,
n_hidden: list[int],
dressing_init_kwargs: dict[str, Any] | None = None,
bias: bool = False,
):
super().__init__()
if dressing_init_kwargs is None:
dressing_init_kwargs = {}
module_list = torch.nn.ModuleList()
layer_size = in_features
if len(n_hidden) > 0:
for n_in, n_out in zip([in_features] + n_hidden, n_hidden):
module_list.append(
DressedLayer(
torch.nn.Linear(in_features=n_in, out_features=n_out, bias=True),
**dressing_init_kwargs,
)
)
layer_size = n_out
module_list.append(torch.nn.Linear(layer_size, out_features, bias=bias))
self.module_list = module_list
self.out_features = out_features
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_ = x
for layer in self.module_list:
x_ = layer(x_)
return x_