# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import warnings
from collections.abc import Iterable
from typing import Any
import lightning.pytorch as pl
import numpy as np
import torch
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRSchedulerConfig
from cellarium.ml.core.datamodule import CellariumAnnDataDataModule
from cellarium.ml.core.pipeline import CellariumPipeline
from cellarium.ml.models import CellariumModel
from cellarium.ml.utilities.core import FunctionComposer, copy_module
class CellariumModule(pl.LightningModule):
"""
``CellariumModule`` organizes code into the following sections:
* :attr:`cpu_transforms`: A list of transforms to apply to the input data as part of the dataloader on CPU.
* :attr:`transforms`: A list of transforms to apply to the input data before passing it to the model.
* :attr:`module_pipeline`: A :class:`CellariumPipeline` to apply all transforms, minus the CPU transforms
if they are handled by a :class:`CellariumAnnDataDataModule`, and the model.
* :attr:`model`: A :class:`CellariumModel` to train with :meth:`training_step` method and epoch end hooks.
* :attr:`optim_fn` and :attr:`optim_kwargs`: A PyTorch optimizer class and its keyword arguments.
* :attr:`scheduler_fn` and :attr:`scheduler_kwargs`: A PyTorch lr scheduler class and its
keyword arguments.
Args:
cpu_transforms:
A list of transforms to apply to the input data as part of the dataloader on CPU.
These transforms get applied before other ``transforms``.
If ``None``, no transforms are applied as part of the dataloader.
transforms:
A list of transforms to apply to the input data before passing it to the model.
If ``None``, no transforms are applied.
model:
A :class:`CellariumModel` to train.
optim_fn:
A PyTorch optimizer class, e.g., :class:`~torch.optim.Adam`. If ``None``,
no optimizer is used.
optim_kwargs:
Keyword arguments for the optimizer.
scheduler_fn:
A PyTorch lr scheduler class, e.g., :class:`~torch.optim.lr_scheduler.CosineAnnealingLR`.
scheduler_kwargs:
Keyword arguments for the lr scheduler.
is_initialized:
Whether the model has been initialized. This defaults to ``False`` under the assumption that
the model was instantiated within a ``torch.device("meta")`` context, and it is set to ``True``
after the first call to :meth:`configure_model`.
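Example:
A minimal construction sketch. ``SomeTransform`` and ``SomeModel`` are placeholders
(not defined in this module) for a :class:`torch.nn.Module` transform and a concrete
:class:`CellariumModel` subclass::

    module = CellariumModule(
        transforms=[SomeTransform()],
        model=SomeModel(),
        optim_fn=torch.optim.Adam,
        optim_kwargs={"lr": 1e-3},
        scheduler_fn=torch.optim.lr_scheduler.CosineAnnealingLR,
        scheduler_kwargs={"T_max": 1_000},
    )

The resulting module is trained like any other :class:`~lightning.pytorch.LightningModule`,
typically by passing it to a :class:`lightning.pytorch.Trainer` together with a
:class:`CellariumAnnDataDataModule`.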
"""
def __init__(
self,
cpu_transforms: Iterable[torch.nn.Module] | None = None,
transforms: Iterable[torch.nn.Module] | None = None,
model: CellariumModel | None = None,
optim_fn: type[torch.optim.Optimizer] | None = None,
optim_kwargs: dict[str, Any] | None = None,
scheduler_fn: type[torch.optim.lr_scheduler.LRScheduler] | None = None,
scheduler_kwargs: dict[str, Any] | None = None,
is_initialized: bool = False,
) -> None:
super().__init__()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Attribute 'model' is an instance of `nn.Module`")
self.save_hyperparameters(logger=False)
self.pipeline: CellariumPipeline | None = None
self._cpu_transforms_in_module_pipeline: bool = True
if optim_fn is None:
# Starting from PyTorch Lightning 2.3, automatic optimization does not allow returning None
# from the training_step during distributed training. https://github.com/Lightning-AI/pytorch-lightning/pull/19918
# Thus, we need to use manual optimization for the no-optimizer case.
self.automatic_optimization = False
def __repr__(self) -> str:
if not self._cpu_transforms_in_module_pipeline:
cpu_trans_str = str(self.cpu_transforms).replace("\n", "\n ")
trans_str = str(self.transforms).replace("\n", "\n ")
repr = (
f"{self.__class__.__name__}("
+ (
f"\n [ dataloader CPU transforms = \n {cpu_trans_str}\n ]"
if not self._cpu_transforms_in_module_pipeline
else ""
)
+ f"\n transforms = {trans_str}"
+ f"\n model = {self.model}"
+ "\n)"
)
else:
repr = f"{self.__class__.__name__}(pipeline = {self.module_pipeline})"
return repr
@property
def model(self) -> CellariumModel:
"""The model"""
if self.pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
assert isinstance(model := self.pipeline[-1], CellariumModel)
return model
@property
def transforms(self) -> CellariumPipeline:
"""The transforms pipeline"""
if self.pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
assert isinstance(transforms := self.pipeline[self._num_cpu_transforms : -1], CellariumPipeline)
return transforms
@property
def cpu_transforms(self) -> CellariumPipeline:
"""The CPU transforms pipeline to be applied by the dataloader"""
if self.pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
assert isinstance(cpu_transforms := self.pipeline[: self._num_cpu_transforms], CellariumPipeline)
return cpu_transforms
@property
def _num_cpu_transforms(self) -> int:
return 0 if self.hparams["cpu_transforms"] is None else len(self.hparams["cpu_transforms"])
@property
def module_pipeline(self) -> CellariumPipeline:
"""The pipeline applied by :meth:`training_step`, :meth:`validation_step`, and :meth:`forward`"""
if self.pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
if self._cpu_transforms_in_module_pipeline:
return self.pipeline
else:
assert isinstance(module_pipeline := self.pipeline[self._num_cpu_transforms :], CellariumPipeline)
return module_pipeline
def training_step( # type: ignore[override]
self, batch: dict[str, dict[str, np.ndarray | torch.Tensor] | np.ndarray | torch.Tensor], batch_idx: int
) -> torch.Tensor | None:
"""
Forward pass for training step.
Args:
batch:
A dictionary containing the batch data.
batch_idx:
The index of the batch.
Returns:
Loss tensor or ``None`` if no loss.
"""
if self.module_pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
output = self.module_pipeline(batch)
loss = output.get("loss")
if loss is not None:
# Logging to TensorBoard by default
self.log("train_loss", loss, sync_dist=True)
if not self.automatic_optimization:
# Note that calling .step() is necessary to increment the global step even though no
# backpropagation is performed.
no_optimizer = self.optimizers()
assert isinstance(no_optimizer, pl.core.optimizer.LightningOptimizer)
no_optimizer.step()
return loss
def forward(
self, batch: dict[str, dict[str, np.ndarray | torch.Tensor] | np.ndarray | torch.Tensor]
) -> dict[str, dict[str, np.ndarray | torch.Tensor] | np.ndarray | torch.Tensor]:
"""
Forward pass for inference step.
Args:
batch: A dictionary containing the batch data.
Returns:
A dictionary containing the batch data and inference outputs.
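Example:
A minimal inference sketch. The ``"x_ng"`` key is an assumption here; the required
batch keys are determined by the configured transforms and model::

    batch = {"x_ng": x_ng}  # tensors keyed by the names the pipeline expects
    output = module(batch)  # the input dict augmented with inference outputs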
"""
if self.module_pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
return self.module_pipeline.predict(batch)
def validation_step(self, batch: dict[str, Any], batch_idx: int) -> None:
"""
Forward pass for validation step.
Args:
batch:
A dictionary containing the batch data.
batch_idx:
The index of the batch.
Returns:
None
"""
if self.module_pipeline is None:
raise RuntimeError("The model is not configured. Call `configure_model` before accessing the model.")
batch["pl_module"] = self
batch["trainer"] = self.trainer
batch["batch_idx"] = batch_idx
self.module_pipeline.validate(batch)
def on_train_epoch_start(self) -> None:
"""
Calls the ``set_epoch`` method on the iterable dataset of each dataloader.
If a dataset is an ``IterableDataset`` with a ``set_epoch`` method defined, then
``set_epoch`` must be called at the beginning of every epoch so that shuffling
produces a new ordering. This has no effect if shuffling is off.
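Example:
A minimal sketch of a dataset that opts into this hook (``MyIterableDataset`` is
hypothetical and not part of this module)::

    class MyIterableDataset(torch.utils.data.IterableDataset):
        def __init__(self) -> None:
            super().__init__()
            self.epoch = 0

        def set_epoch(self, epoch: int) -> None:
            # record the epoch so that __iter__ can derive a per-epoch shuffle seed
            self.epoch = epoch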
"""
# dataloader is wrapped in a combined loader and can be accessed via
# flattened property which returns a list of dataloaders
# https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.utilities.combined_loader.html
combined_loader = self.trainer.fit_loop._combined_loader
assert combined_loader is not None
dataloaders = combined_loader.flattened
for dataloader in dataloaders:
dataset = dataloader.dataset
set_epoch = getattr(dataset, "set_epoch", None)
if callable(set_epoch):
set_epoch(self.current_epoch)
def on_train_start(self) -> None:
"""
Calls the ``on_train_start`` method on the :attr:`model` attribute.
If the :attr:`model` attribute defines an ``on_train_start`` method, it is
called here at the beginning of training.
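Example:
A minimal sketch of a model that opts into this hook (``MyModel`` is hypothetical)::

    class MyModel(CellariumModel):
        def on_train_start(self, trainer: pl.Trainer) -> None:
            # e.g., inspect the trainer configuration before the first batch
            ...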
"""
on_train_start = getattr(self.model, "on_train_start", None)
if callable(on_train_start):
on_train_start(self.trainer)
def on_train_epoch_end(self) -> None:
"""
Calls the ``set_resume_step`` method on the iterable dataset of each dataloader.
If a dataset is an ``IterableDataset`` with a ``set_resume_step`` method defined, then
``set_resume_step`` must be called at the end of every epoch to ensure that the dataset
is in the correct state for resuming training.
Also calls the ``on_train_epoch_end`` method on the :attr:`model` attribute.
If the :attr:`model` attribute defines an ``on_train_epoch_end`` method, it is
called at the end of every epoch.
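Example:
A minimal sketch of a dataset that opts into the resume-step hook (``MyIterableDataset``
is hypothetical and not part of this module)::

    class MyIterableDataset(torch.utils.data.IterableDataset):
        def set_resume_step(self, resume_step: int | None) -> None:
            # ``None`` clears any mid-epoch resume point once the epoch completes
            self.resume_step = resume_step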
"""
combined_loader = self.trainer.fit_loop._combined_loader
assert combined_loader is not None
dataloaders = combined_loader.flattened
for dataloader in dataloaders:
dataset = dataloader.dataset
set_resume_step = getattr(dataset, "set_resume_step", None)
if callable(set_resume_step):
set_resume_step(None)
on_train_epoch_end = getattr(self.model, "on_train_epoch_end", None)
if callable(on_train_epoch_end):
on_train_epoch_end(self.trainer)
def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int) -> None:
"""
Calls the ``on_train_batch_end`` method on the :attr:`model` attribute, if it is defined.
"""
on_train_batch_end = getattr(self.model, "on_train_batch_end", None)
if callable(on_train_batch_end):
on_train_batch_end(self.trainer)
def move_cpu_transforms_to_dataloader(self) -> None:
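"""
Hand the :attr:`cpu_transforms` over to the datamodule's ``collate_fn`` so that they
are applied on CPU as part of the dataloader and excluded from :attr:`module_pipeline`.
This only takes effect when the attached datamodule is a
:class:`CellariumAnnDataDataModule`; otherwise the pipeline is left unchanged.
"""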
if not self._cpu_transforms_in_module_pipeline:
warnings.warn(
"The CPU transforms are already moved to the dataloader's collate_fn. Skipping the move operation.",
UserWarning,
)
return
if self._trainer is not None:
if hasattr(self.trainer, "datamodule"):
if isinstance(self.trainer.datamodule, CellariumAnnDataDataModule):
self._cpu_transforms_in_module_pipeline = False
self.trainer.datamodule.collate_fn = FunctionComposer(
first_applied=self.trainer.datamodule.collate_fn,
second_applied=self.cpu_transforms,
)
def setup(self, stage: str) -> None:
# move the cpu_transforms to the dataloader's collate_fn if the dataloader is going to apply them
if self.pipeline is not None:
self.move_cpu_transforms_to_dataloader()
def teardown(self, stage: str) -> None:
# move the cpu_transforms back to the module_pipeline from dataloader's collate_fn
if not self._cpu_transforms_in_module_pipeline:
self.trainer.datamodule.collate_fn = self.trainer.datamodule.collate_fn.first_applied # type: ignore[attr-defined]
self._cpu_transforms_in_module_pipeline = True
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
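"""
Manually advance the fit-loop progress counters stored in the checkpoint.
Checkpoint callbacks can run before Lightning marks the current batch (and, on the
last batch, the current epoch) as completed, so the saved loop state and the
datamodule's epoch counter are bumped here to keep resumed training from repeating work.
"""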
fit_loop = self.trainer.fit_loop
epoch_loop = fit_loop.epoch_loop
batch_progress = epoch_loop.batch_progress
if batch_progress.current.completed < batch_progress.current.processed: # type: ignore[attr-defined]
# Checkpointing is done before these attributes are updated. So, we need to update them manually.
checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["total"]["completed"] += 1
checkpoint["loops"]["fit_loop"]["epoch_loop.batch_progress"]["current"]["completed"] += 1
if not epoch_loop._should_accumulate():
checkpoint["loops"]["fit_loop"]["epoch_loop.state_dict"]["_batches_that_stepped"] += 1
if batch_progress.is_last_batch:
checkpoint["loops"]["fit_loop"]["epoch_progress"]["total"]["processed"] += 1
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["processed"] += 1
checkpoint["loops"]["fit_loop"]["epoch_progress"]["total"]["completed"] += 1
checkpoint["loops"]["fit_loop"]["epoch_progress"]["current"]["completed"] += 1
checkpoint["CellariumAnnDataDataModule"]["epoch"] += 1