Source code for

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Literal

import lightning.pytorch as pl
import numpy as np
import torch
import torch.distributed as dist
from lightning.pytorch.strategies import DDPStrategy

from import CellariumModel
from import get_rank_and_num_replicas
from import (

class OnePassMeanVarStd(CellariumModel):
    """
    Calculate the mean, variance, and standard deviation of the data in one pass (epoch)
    using running sums and running squared sums.

    **References:**

    1. `Algorithms for calculating variance
       <https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance>`_.

    Args:
        var_names_g:
            The variable names schema for the input data validation.
        algorithm:
            Accumulation scheme. ``"naive"`` accumulates the raw sums and squared sums.
            ``"shifted_data"`` subtracts a per-variable shift (the mean of the first
            batch) before accumulating, which reduces cancellation error in the
            variance when the mean is large relative to the spread.
    """

    def __init__(self, var_names_g: np.ndarray, algorithm: Literal["naive", "shifted_data"] = "naive") -> None:
        super().__init__()
        self.var_names_g = var_names_g
        n_vars = len(self.var_names_g)
        self.n_vars = n_vars
        self.algorithm = algorithm

        # Type declarations for the buffers registered below.
        self.x_sums: torch.Tensor
        self.x_squared_sums: torch.Tensor
        self.x_size: torch.Tensor
        self.x_shift: torch.Tensor | None

        self.register_buffer("x_sums", torch.empty(n_vars))
        self.register_buffer("x_squared_sums", torch.empty(n_vars))
        self.register_buffer("x_size", torch.empty(()))
        if self.algorithm == "shifted_data":
            self.register_buffer("x_shift", torch.empty(n_vars))
        else:
            # Registered as ``None`` so the attribute exists for every algorithm.
            self.register_buffer("x_shift", None)

        # NOTE(review): presumably present so that the module exposes at least one
        # parameter to the training framework -- confirm against CellariumModel.
        self._dummy_param = torch.nn.Parameter(torch.empty(()))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset all running statistics (and the dummy parameter) to zero."""
        self.x_sums.zero_()
        self.x_squared_sums.zero_()
        self.x_size.zero_()
        if self.x_shift is not None:
            self.x_shift.zero_()
        # ``torch.empty`` leaves arbitrary memory behind; zero it so that the
        # parameter (and any checkpoint containing it) is deterministic.
        self._dummy_param.data.zero_()

    def forward(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, torch.Tensor | None]:
        """
        Accumulate the running sums for one batch.

        Args:
            x_ng:
                Gene counts matrix.
            var_names_g:
                The list of the variable names in the input data.

        Returns:
            An empty dictionary.

        Raises:
            ValueError: If ``self.algorithm`` is not a supported algorithm.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

        if self.algorithm == "naive":
            self.x_sums = self.x_sums + x_ng.sum(dim=0)
            self.x_squared_sums = self.x_squared_sums + (x_ng**2).sum(dim=0)
            self.x_size = self.x_size + x_ng.shape[0]
        elif self.algorithm == "shifted_data":
            assert self.x_shift is not None
            # The shift must be chosen exactly once, before anything has been
            # accumulated. Testing ``x_size`` rather than ``x_shift == 0`` keeps a
            # first batch whose mean happens to be zero from triggering a second
            # shift estimate later, which would mix sums taken around two
            # different reference points.
            if self.x_size == 0:
                _, world_size = get_rank_and_num_replicas()
                if world_size > 1:
                    # Use the batches from every process so that all of them
                    # end up with the same shift.
                    gathered_x_ng = torch.zeros(
                        world_size * x_ng.shape[0], x_ng.shape[1], dtype=x_ng.dtype, device=x_ng.device
                    )
                    dist.all_gather_into_tensor(gathered_x_ng, x_ng)
                    x_shift = gathered_x_ng.mean(dim=0)
                else:
                    x_shift = x_ng.mean(dim=0)
                self.x_shift = x_shift
            shifted_x_ng = x_ng - self.x_shift
            self.x_sums = self.x_sums + shifted_x_ng.sum(dim=0)
            self.x_squared_sums = self.x_squared_sums + (shifted_x_ng**2).sum(dim=0)
            self.x_size = self.x_size + x_ng.shape[0]
        else:
            raise ValueError(f"Unknown algorithm: {self.algorithm}")
        return {}

    def on_train_start(self, trainer: pl.Trainer) -> None:
        """
        Check that multi-process training uses DDP with buffer broadcasting disabled.

        Each process accumulates its own running sums in module buffers, so DDP
        must not synchronize buffers across processes during training.
        """
        if trainer.world_size > 1:
            assert isinstance(trainer.strategy, DDPStrategy), (
                "OnePassMeanVarStd requires that the trainer uses the DDP strategy."
            )
            # DDP broadcasts buffers by default, so an unset option is treated as
            # ``True`` and reported with the message below instead of a KeyError.
            assert trainer.strategy._ddp_kwargs.get("broadcast_buffers", True) is False, (
                "OnePassMeanVarStd requires that broadcast_buffers is set to False."
            )

    def on_train_epoch_end(self, trainer: pl.Trainer) -> None:
        """Sum the per-process running statistics onto the process with rank 0."""
        # no need to merge if only one process
        if trainer.world_size == 1:
            return

        # merge the running sums; only rank 0 receives the reduced totals
        dist.reduce(self.x_sums, dst=0, op=dist.ReduceOp.SUM)
        dist.reduce(self.x_squared_sums, dst=0, op=dist.ReduceOp.SUM)
        dist.reduce(self.x_size, dst=0, op=dist.ReduceOp.SUM)

    @property
    def mean_g(self) -> torch.Tensor:
        """
        Mean of the data.
        """
        mean_g = self.x_sums / self.x_size
        if self.algorithm == "shifted_data":
            assert isinstance(self.x_shift, torch.Tensor)
            # The sums were accumulated around ``x_shift``; add it back.
            mean_g = mean_g + self.x_shift
        return mean_g

    @property
    def var_g(self) -> torch.Tensor:
        """
        Variance of the data (normalized by the number of observations).

        The value does not depend on the shift, so the same formula serves both
        algorithms.
        """
        mean_of_squares_g = self.x_squared_sums / self.x_size
        squared_mean_g = (self.x_sums / self.x_size) ** 2
        # Floating-point cancellation can push the difference slightly below
        # zero, which would make ``std_g`` NaN; the variance is non-negative.
        return torch.clamp(mean_of_squares_g - squared_mean_g, min=0)

    @property
    def std_g(self) -> torch.Tensor:
        """
        Standard deviation of the data.
        """
        return torch.sqrt(self.var_g)