Source code for cellarium.ml.models.model

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Callable
from typing import Any

import numpy as np
import torch
from pyro.nn.module import PyroParam, _unconstrain
from torch.distributions import transform_to


[docs] class CellariumModel(torch.nn.Module, metaclass=ABCMeta): """ Base class for Cellarium ML compatible models. """ def __init__(self) -> None: super(torch.nn.Module, self).__setattr__("_pyro_params", OrderedDict()) super().__init__() __call__: Callable[..., torch.Tensor | None]
[docs] @abstractmethod def reset_parameters(self) -> None: """ Reset the model parameters and buffers that were constructed in the __init__ method. Constructed means methods like ``torch.tensor``, ``torch.empty``, ``torch.zeros``, ``torch.randn``, ``torch.as_tensor``, etc. If the parameter or buffer was assigned (e.g. from a torch.Tensor passed to the __init__) then there is no need to reset it. """
def __getattr__(self, name: str) -> Any: if "_pyro_params" in self.__dict__: _pyro_params = self.__dict__["_pyro_params"] if name in _pyro_params: constraint, event_dim = _pyro_params[name] unconstrained_value = getattr(self, name + "_unconstrained") constrained_value = transform_to(constraint)(unconstrained_value) return constrained_value return super().__getattr__(name) def __setattr__(self, name: str, value: torch.Tensor | torch.nn.Module | PyroParam) -> None: if isinstance(value, PyroParam): # Create a new PyroParam, overwriting any old value. try: delattr(self, name) except AttributeError: pass constrained_value, constraint, event_dim = value self._pyro_params[name] = constraint, event_dim assert constrained_value is not None unconstrained_value = _unconstrain(constrained_value, constraint) super().__setattr__(name + "_unconstrained", unconstrained_value) return super().__setattr__(name, value) def __delattr__(self, name: str) -> None: if name in self._pyro_params: delattr(self, name + "_unconstrained") del self._pyro_params[name] else: super().__delattr__(name)
[docs] class PredictMixin(metaclass=ABCMeta): """ Abstract mixin class for models that can perform prediction. """
[docs] @abstractmethod def predict(self, *args: Any, **kwargs: Any) -> dict[str, np.ndarray | torch.Tensor]: """ Perform prediction. """
[docs] class ValidateMixin(metaclass=ABCMeta): """ Abstract mixin class for models that can perform validation. """
[docs] @abstractmethod def validate(self, *args: Any, **kwargs: Any) -> None: """ Perform validation. """