# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from collections.abc import Iterable
from typing import Any
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRSchedulerConfig
from cellarium.ml.core.datamodule import CellariumAnnDataDataModule, collate_fn
from cellarium.ml.core.pipeline import CellariumPipeline
from cellarium.ml.models import CellariumModel
from cellarium.ml.utilities.core import FunctionComposer, copy_module
[docs]
class CellariumModule(pl.LightningModule):
"""
``CellariumModule`` organizes code into the following sections:
* :attr:`cpu_transforms`: A list of transforms to apply to the input data as part of the dataloader on CPU.
* :attr:`transforms`: A list of transforms to apply to the input data before passing it to the model.
* :attr:`module_pipeline`: A :class:`CellariumPipeline` to apply all transforms, minus the CPU transforms
if they are handled by a :class:`CellariumAnnDataDataModule`, and the model.
* :attr:`model`: A :class:`CellariumModel` to train with :meth:`training_step` method and epoch end hooks.
* :attr:`optim_fn` and :attr:`optim_kwargs`: A Pytorch optimizer class and its keyword arguments.
* :attr:`scheduler_fn` and :attr:`scheduler_kwargs`: A Pytorch lr scheduler class and its
keyword arguments.
Args:
cpu_transforms:
A list of transforms to apply to the input data as part of the dataloader on CPU.
These transforms get applied before other ``transforms``.
If ``None``, no transforms are applied as part of the dataloader.
transforms:
A list of transforms to apply to the input data before passing it to the model.
If ``None``, no transforms are applied.
model:
A :class:`CellariumModel` to train.
optim_fn:
A Pytorch optimizer class, e.g., :class:`~torch.optim.Adam`. If ``None``,
no optimizer is used.
optim_kwargs:
Keyword arguments for the optimizer.
scheduler_fn:
A Pytorch lr scheduler class, e.g., :class:`~torch.optim.lr_scheduler.CosineAnnealingLR`.
scheduler_kwargs:
Keyword arguments for lr scheduler.
is_initialized:
Whether the model has been initialized. This is set to ``False`` by default under the assumption that
``torch.device("meta")`` context was used and is set to ``True`` after
the first call to :meth:`configure_model`.
"""
def __init__(
    self,
    cpu_transforms: Iterable[torch.nn.Module] | None = None,
    transforms: Iterable[torch.nn.Module] | None = None,
    model: CellariumModel | None = None,
    optim_fn: type[torch.optim.Optimizer] | None = None,
    optim_kwargs: dict[str, Any] | None = None,
    scheduler_fn: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
    scheduler_kwargs: dict[str, Any] | None = None,
    is_initialized: bool = False,
) -> None:
    # Create an unconfigured module: record the constructor arguments and set the initial state.
    # The pipeline itself is not built here.
    super().__init__()
    with warnings.catch_warnings():
        # Suppress only the warning about ``model`` being an ``nn.Module`` while hyperparameters are saved.
        warnings.filterwarnings("ignore", message="Attribute 'model' is an instance of `nn.Module`")
        # NOTE(review): presumably this captures the constructor arguments from the current frame into
        # ``self.hparams`` (``_num_cpu_transforms`` reads ``self.hparams["cpu_transforms"]``), so avoid
        # introducing locals or reordering statements before this call without confirming against Lightning.
        self.save_hyperparameters(logger=False)
    # Stays ``None`` until the pipeline is assembled; the accessor properties raise ``RuntimeError`` while unset.
    self.pipeline: CellariumPipeline | None = None
    # When ``True``, ``module_pipeline`` returns the whole pipeline, CPU transforms included.
    self._cpu_transforms_in_module_pipeline: bool = True
    if optim_fn is None:
        # Starting from PyTorch Lightning 2.3, automatic optimization doesn't allow to return None
        # from the training_step during distributed training. https://github.com/Lightning-AI/pytorch-lightning/pull/19918
        # Thus, we need to use manual optimization for the No Optimizer case.
        self.automatic_optimization = False
def __repr__(self) -> str:
    """
    Return a human-readable description of the module.

    When the CPU transforms are part of the module pipeline, the whole
    :attr:`module_pipeline` is shown as a single entry. Otherwise the dataloader CPU
    transforms, the remaining transforms, and the model are listed separately.

    Returns:
        The string representation.

    Raises:
        RuntimeError: If the pipeline has not been configured yet.
    """
    class_name = self.__class__.__name__
    if self._cpu_transforms_in_module_pipeline:
        return f"{class_name}(pipeline = {self.module_pipeline})"

    # Indent nested multi-line representations so they line up under their heading.
    cpu_trans_str = str(self.cpu_transforms).replace("\n", "\n ")
    trans_str = str(self.transforms).replace("\n", "\n ")
    # The CPU-transforms section is always emitted here: this branch is only reached when the
    # CPU transforms are handled outside of the module pipeline.
    return (
        f"{class_name}("
        f"\n [ dataloader CPU transforms = \n {cpu_trans_str}\n ]"
        f"\n transforms = {trans_str}"
        f"\n model = {self.model}"
        "\n)"
    )
@property
def model(self) -> CellariumModel:
    """The model, i.e. the last element of the pipeline."""
    pipeline = self.pipeline
    if pipeline is not None:
        return pipeline[-1]
    raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
@property
def transforms(self) -> CellariumPipeline:
    """The transforms pipeline: everything between the CPU transforms and the final model."""
    pipeline = self.pipeline
    if pipeline is not None:
        first = self._num_cpu_transforms
        return pipeline[first:-1]
    raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
@property
def cpu_transforms(self) -> CellariumPipeline:
    """The leading CPU transforms of the pipeline, to be applied by the dataloader."""
    pipeline = self.pipeline
    if pipeline is not None:
        stop = self._num_cpu_transforms
        return pipeline[:stop]
    raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
@property
def _num_cpu_transforms(self) -> int:
    # Number of CPU transforms recorded in the hyperparameters; zero when none were given.
    if self.hparams["cpu_transforms"] is None:
        return 0
    return len(self.hparams["cpu_transforms"])
@property
def module_pipeline(self) -> CellariumPipeline:
    """The pipeline applied by :meth:`training_step`, :meth:`validation_step`, and :meth:`forward`"""
    pipeline = self.pipeline
    if pipeline is None:
        raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
    if self._cpu_transforms_in_module_pipeline:
        # CPU transforms are applied by the module itself, so the full pipeline is used.
        return pipeline
    # Otherwise skip the leading CPU transforms.
    return pipeline[self._num_cpu_transforms :]
[docs]
def training_step(  # type: ignore[override]
    self, batch: dict[str, np.ndarray | torch.Tensor], batch_idx: int
) -> torch.Tensor | None:
    """
    Forward pass for training step.

    Args:
        batch:
            A dictionary containing the batch data.
        batch_idx:
            The index of the batch.

    Returns:
        Loss tensor or ``None`` if no loss.

    Raises:
        RuntimeError: If the pipeline has not been configured yet.
    """
    # ``module_pipeline`` raises ``RuntimeError`` when the pipeline is unset and never returns ``None``,
    # so it is read once here instead of being checked against ``None`` first.
    output = self.module_pipeline(batch)
    loss = output.get("loss")
    if loss is not None:
        # Logging to TensorBoard by default
        self.log("train_loss", loss, sync_dist=True)

    if not self.automatic_optimization:
        # Note, that running .step() is necessary for incrementing the global step even though no backpropagation
        # is performed.
        no_optimizer = self.optimizers()
        assert isinstance(no_optimizer, pl.core.optimizer.LightningOptimizer)
        no_optimizer.step()

    return loss
[docs]
def forward(self, batch: dict[str, np.ndarray | torch.Tensor]) -> dict[str, np.ndarray | torch.Tensor]:
    """
    Forward pass for inference step.

    Args:
        batch: A dictionary containing the batch data.

    Returns:
        A dictionary containing the batch data and inference outputs.

    Raises:
        RuntimeError: If the pipeline has not been configured yet.
    """
    # ``module_pipeline`` raises ``RuntimeError`` when the pipeline is unset and never returns ``None``,
    # so no separate ``None`` check is needed before delegating to ``predict``.
    return self.module_pipeline.predict(batch)
[docs]
def validation_step(self, batch: dict[str, Any], batch_idx: int) -> None:
    """
    Forward pass for validation step.

    Args:
        batch:
            A dictionary containing the batch data.
        batch_idx:
            The index of the batch.

    Returns:
        None

    Raises:
        RuntimeError: If the pipeline has not been configured yet.
    """
    # ``module_pipeline`` raises ``RuntimeError`` when the pipeline is unset and never returns ``None``,
    # so no separate ``None`` check is needed before delegating to ``validate``.
    self.module_pipeline.validate(batch)
[docs]
def on_train_epoch_start(self) -> None:
    """
    Pass the current epoch to every training dataset that exposes a callable ``set_epoch``.

    Datasets without a callable ``set_epoch`` attribute are left untouched.
    """
    # The training dataloader is wrapped in a combined loader; its ``flattened``
    # property yields the individual dataloaders.
    # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.combined_loader.html
    combined_loader = self.trainer.fit_loop._combined_loader
    assert combined_loader is not None
    for loader in combined_loader.flattened:
        epoch_setter = getattr(loader.dataset, "set_epoch", None)
        if not callable(epoch_setter):
            continue
        epoch_setter(self.current_epoch)
[docs]
def on_train_start(self) -> None:
    """
    Forward the train-start event to the :attr:`model`.

    If the :attr:`model` exposes a callable ``on_train_start``, it is invoked with the trainer;
    otherwise nothing happens.
    """
    hook = getattr(self.model, "on_train_start", None)
    if not callable(hook):
        return
    hook(self.trainer)
[docs]
def on_train_epoch_end(self) -> None:
    """
    Forward the train-epoch-end event to the :attr:`model`.

    If the :attr:`model` exposes a callable ``on_train_epoch_end``, it is invoked with the trainer;
    otherwise nothing happens.
    """
    hook = getattr(self.model, "on_train_epoch_end", None)
    if not callable(hook):
        return
    hook(self.trainer)
[docs]
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
    """
    Forward the train-batch-end event to the :attr:`model`.

    If the :attr:`model` exposes a callable ``on_train_batch_end``, it is invoked with the trainer.
    The ``outputs``, ``batch`` and ``batch_idx`` arguments are not passed on.
    """
    hook = getattr(self.model, "on_train_batch_end", None)
    if not callable(hook):
        return
    hook(self.trainer)