# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from typing import Any
import lightning.pytorch as pl
import numpy as np
import torch
from pyro.nn.module import PyroParam, _unconstrain
from torch.distributions import transform_to
[docs]
class CellariumModel(torch.nn.Module, metaclass=ABCMeta):
"""
Base class for Cellarium ML compatible models.
"""
def __init__(self) -> None:
super(torch.nn.Module, self).__setattr__("_pyro_params", OrderedDict())
super().__init__()
__call__: Callable[..., dict[str, torch.Tensor | None]]
[docs]
@abstractmethod
def reset_parameters(self) -> None:
"""
Reset the model parameters and buffers that were constructed in the __init__ method.
Constructed means methods like ``torch.tensor``, ``torch.empty``, ``torch.zeros``,
``torch.randn``, ``torch.as_tensor``, etc.
If the parameter or buffer was assigned (e.g. from a torch.Tensor passed to the __init__)
then there is no need to reset it.
"""
def __getattr__(self, name: str) -> Any:
if "_pyro_params" in self.__dict__:
_pyro_params = self.__dict__["_pyro_params"]
if name in _pyro_params:
constraint, event_dim = _pyro_params[name]
unconstrained_value = getattr(self, name + "_unconstrained")
constrained_value = transform_to(constraint)(unconstrained_value)
return constrained_value
return super().__getattr__(name)
def __setattr__(self, name: str, value: torch.Tensor | torch.nn.Module | PyroParam) -> None:
if isinstance(value, PyroParam):
# Create a new PyroParam, overwriting any old value.
try:
delattr(self, name)
except AttributeError:
pass
constrained_value, constraint, event_dim = value
self._pyro_params[name] = constraint, event_dim
assert constrained_value is not None
unconstrained_value = _unconstrain(constrained_value, constraint)
super().__setattr__(name + "_unconstrained", unconstrained_value)
return
super().__setattr__(name, value)
def __delattr__(self, name: str) -> None:
if name in self._pyro_params:
delattr(self, name + "_unconstrained")
del self._pyro_params[name]
else:
super().__delattr__(name)
[docs]
class PredictMixin(metaclass=ABCMeta):
"""
Abstract mixin class for models that can perform prediction.
"""
[docs]
@abstractmethod
def predict(self, *args: Any, **kwargs: Any) -> dict[str, np.ndarray | torch.Tensor]:
"""
Perform prediction.
"""
[docs]
class ValidateMixin:
"""
Mixin class for models that can perform validation.
"""
__call__: Callable[..., dict[str, torch.Tensor | None]]
[docs]
def validate(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
batch_idx: int,
*args: Any,
**kwargs: Any,
) -> None:
"""
Default validation method for models. This method logs the validation loss to TensorBoard.
Override this method to customize the validation behavior.
"""
output = self(*args, **kwargs)
loss = output.get("loss")
if loss is not None:
# Logging to TensorBoard by default
pl_module.log("val_loss", loss, sync_dist=True, on_epoch=True)