# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import numpy as np
import torch
from torch import nn
from cellarium.ml.utilities.testing import (
assert_arrays_equal,
assert_columns_and_array_lengths_equal,
assert_nonnegative,
)
class ZScore(nn.Module):
    """
    ZScore gene counts with mean and standard deviation.

    .. math::

        y_{ng} = \\frac{x_{ng} - \\mathrm{mean}_g}{\\mathrm{std}_g + \\mathrm{eps}}

    Args:
        mean_g:
            Means for each gene.
        std_g:
            Standard deviations for each gene.
        var_names_g:
            The variable names schema for the input data validation.
        eps:
            A value added to the denominator for numerical stability.
            It is checked with ``assert_nonnegative`` at construction time.
    """

    def __init__(
        self,
        mean_g: torch.Tensor,
        std_g: torch.Tensor,
        var_names_g: np.ndarray,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        # Annotations only: the attributes themselves are created by
        # ``register_buffer`` below.
        self.mean_g: torch.Tensor
        self.std_g: torch.Tensor
        # Registered as buffers (not parameters), so the statistics belong to
        # the module's state without being trainable.
        self.register_buffer("mean_g", mean_g)
        self.register_buffer("std_g", std_g)
        self.var_names_g = var_names_g
        assert_nonnegative("eps", eps)
        self.eps = eps

    def forward(
        self,
        x_ng: torch.Tensor,
        var_names_g: np.ndarray,
    ) -> dict[str, torch.Tensor]:
        """
        Z-score ``x_ng`` with the stored per-gene mean and standard deviation.

        .. note::
            When used with :class:`~cellarium.ml.core.CellariumModule` or
            :class:`~cellarium.ml.core.CellariumPipeline`, ``x_ng`` key in the
            input dictionary will be overwritten with the z-scored values.

        Args:
            x_ng:
                Gene counts.
            var_names_g:
                The list of the variable names in the input data. It is always
                validated against ``x_ng`` and against the ``var_names_g``
                schema given at construction time.

        Returns:
            A dictionary with the following keys:

            - ``x_ng``: The z-scored gene counts.
        """
        # Validate the input before touching the data: the names must line up
        # with the columns of ``x_ng`` and with the schema stored on the module.
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

        # ``eps`` is added to the denominator for numerical stability.
        x_ng = (x_ng - self.mean_g) / (self.std_g + self.eps)
        return {"x_ng": x_ng}

    def __repr__(self) -> str:
        """Return a constructor-like summary of the module's configuration."""
        # The closing parenthesis comes after ``eps`` so that every field sits
        # inside ``ZScore(...)``.
        return (
            f"{self.__class__.__name__}(mean_g={self.mean_g}, std_g={self.std_g}, "
            f"var_names_g={self.var_names_g}, eps={self.eps})"
        )