Source code for cellarium.ml.models.onepass_mean_var_std

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Literal

import lightning.pytorch as pl
import numpy as np
import torch
import torch.distributed as dist
from lightning.pytorch.strategies import DDPStrategy

from cellarium.ml.models.model import CellariumModel
from cellarium.ml.utilities.distributed import get_rank_and_num_replicas
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)


class OnePassMeanVarStd(CellariumModel):
    """
    Calculate the mean, variance, and standard deviation of the data in one pass (epoch)
    using running sums and running squared sums.

    **References:**

    1. `Algorithms for calculating variance <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance>`_.

    Args:
        var_names_g:
            The variable names schema for the input data validation.
        algorithm:
            Which accumulation scheme to use. ``"naive"`` accumulates the raw sums
            and squared sums. ``"shifted_data"`` subtracts a per-variable shift
            (the mean of the first batch) before accumulating, and adds it back
            when the mean is reported.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported values.
    """

    def __init__(self, var_names_g: np.ndarray, algorithm: Literal["naive", "shifted_data"] = "naive") -> None:
        super().__init__()
        # Fail at construction time rather than on the first forward pass.
        if algorithm not in ("naive", "shifted_data"):
            raise ValueError(f"Unknown algorithm: {algorithm}")
        self.var_names_g = var_names_g
        n_vars = len(self.var_names_g)
        self.n_vars = n_vars
        self.algorithm = algorithm

        self.x_sums: torch.Tensor
        self.x_squared_sums: torch.Tensor
        self.x_size: torch.Tensor
        self.x_shift: torch.Tensor | None
        self.register_buffer("x_sums", torch.empty(n_vars))
        self.register_buffer("x_squared_sums", torch.empty(n_vars))
        self.register_buffer("x_size", torch.empty(()))
        if self.algorithm == "shifted_data":
            self.register_buffer("x_shift", torch.empty(n_vars))
        else:
            self.register_buffer("x_shift", None)
        # NOTE(review): presumably present so that the module owns at least one
        # parameter; it is never used in the computation -- confirm.
        self._dummy_param = torch.nn.Parameter(torch.empty(()))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Zero all running statistics (and the shift, when it is used)."""
        self.x_sums.zero_()
        self.x_squared_sums.zero_()
        self.x_size.zero_()
        if self.x_shift is not None:
            self.x_shift.zero_()
        self._dummy_param.data.zero_()

    def forward(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, torch.Tensor | None]:
        """
        Accumulate the running sums, squared sums, and sample count for one batch.

        Args:
            x_ng:
                Gene counts matrix.
            var_names_g:
                The list of the variable names in the input data.

        Returns:
            An empty dictionary.

        Raises:
            ValueError: If ``self.algorithm`` is not a supported value.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

        if self.algorithm == "naive":
            self.x_sums = self.x_sums + x_ng.sum(dim=0)
            self.x_squared_sums = self.x_squared_sums + (x_ng**2).sum(dim=0)
            self.x_size = self.x_size + x_ng.shape[0]
        elif self.algorithm == "shifted_data":
            assert self.x_shift is not None
            # Estimate the shift only while nothing has been accumulated yet.
            # Keying on the sample count (instead of "shift is all zeros") keeps
            # the shift fixed even when the first batch happens to have a zero
            # mean; changing it after sums were accumulated would corrupt them.
            if self.x_size == 0:
                _, world_size = get_rank_and_num_replicas()
                if world_size > 1:
                    # Use the same shift on every rank: gather the first batch
                    # from all ranks and take the mean of the combined data.
                    gathered_x_ng = torch.zeros(
                        world_size * x_ng.shape[0], x_ng.shape[1], dtype=x_ng.dtype, device=x_ng.device
                    )
                    dist.all_gather_into_tensor(gathered_x_ng, x_ng)
                    x_shift = gathered_x_ng.mean(dim=0)
                else:
                    x_shift = x_ng.mean(dim=0)
                self.x_shift = x_shift
            shifted_x_ng = x_ng - self.x_shift
            self.x_sums = self.x_sums + shifted_x_ng.sum(dim=0)
            self.x_squared_sums = self.x_squared_sums + (shifted_x_ng**2).sum(dim=0)
            self.x_size = self.x_size + x_ng.shape[0]
        else:
            raise ValueError(f"Unknown algorithm: {self.algorithm}")
        return {}

    def on_train_start(self, trainer: pl.Trainer) -> None:
        """
        Check that a multi-process run uses DDP with buffer broadcasting disabled.

        Args:
            trainer: The trainer running this model.
        """
        if trainer.world_size > 1:
            assert isinstance(
                trainer.strategy, DDPStrategy
            ), "OnePassMeanVarStd requires that the trainer uses the DDP strategy."
            # DistributedDataParallel defaults ``broadcast_buffers`` to True, so
            # an option that was never set is treated as enabled instead of
            # surfacing as a bare KeyError.
            assert (
                trainer.strategy._ddp_kwargs.get("broadcast_buffers", True) is False
            ), "OnePassMeanVarStd requires that broadcast_buffers is set to False."

    def on_train_epoch_end(self, trainer: pl.Trainer) -> None:
        """
        Sum the per-process running statistics onto rank 0.

        Args:
            trainer: The trainer running this model.
        """
        # no need to merge if only one process
        if trainer.world_size == 1:
            return

        # merge the running sums
        dist.reduce(self.x_sums, dst=0, op=dist.ReduceOp.SUM)
        dist.reduce(self.x_squared_sums, dst=0, op=dist.ReduceOp.SUM)
        dist.reduce(self.x_size, dst=0, op=dist.ReduceOp.SUM)

    @property
    def mean_g(self) -> torch.Tensor:
        """
        Mean of the data.
        """
        mean_g = self.x_sums / self.x_size
        if self.algorithm == "shifted_data":
            # The sums were accumulated on shifted data; undo the shift.
            mean_g = mean_g + self.x_shift
        return mean_g

    @property
    def var_g(self) -> torch.Tensor:
        """
        Variance of the data.
        """
        # Subtracting a constant shift does not change the variance, so the
        # same formula is used for both algorithms.
        var_g = self.x_squared_sums / self.x_size - (self.x_sums / self.x_size) ** 2
        # Floating-point cancellation can leave a tiny negative value, which
        # would turn the standard deviation into NaN; a variance is never
        # negative, so clamp at zero.
        return torch.clamp(var_g, min=0)

    @property
    def std_g(self) -> torch.Tensor:
        """
        Standard deviation of the data.
        """
        return torch.sqrt(self.var_g)