Source code for cellarium.ml.transforms.gaussian_noise

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


import torch
from torch import nn

from .randomize import Randomize


class GaussianNoise(nn.Module):
    """
    Adds Gaussian noise to gene counts.

    For each count, Gaussian sigma is independently and uniformly sampled
    according to the bounding parameters, yielding the sigma matrix sigma_ng.

    .. math::

        y_{ng} = x_{ng} + N(0, \\sigma_{ng})

    Args:
        sigma_min:
            Lower bound on Gaussian sigma parameter. Must be non-negative.
        sigma_max:
            Upper bound on Gaussian sigma parameter. Must be greater than or
            equal to ``sigma_min``.
        p_apply:
            Probability of applying transform to each sample.

    Raises:
        ValueError: If ``sigma_min`` is negative or ``sigma_max`` is smaller
            than ``sigma_min``.
    """

    def __init__(self, sigma_min: float, sigma_max: float, p_apply: float) -> None:
        super().__init__()

        # Fail fast on an invalid range: otherwise the error only surfaces
        # inside ``forward`` (from ``uniform_`` / ``torch.normal``), and for a
        # slightly negative lower bound it may surface only intermittently.
        if sigma_min < 0:
            raise ValueError(f"`sigma_min` must be non-negative. Got {sigma_min}")
        if sigma_max < sigma_min:
            raise ValueError(
                "`sigma_max` must be greater than or equal to `sigma_min`. "
                f"Got sigma_min={sigma_min}, sigma_max={sigma_max}"
            )

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.randomize = Randomize(p_apply)

    def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x_ng:
                Gene counts (log-transformed).

        Returns:
            A dictionary with the following key:

            - ``x_ng``: The output of ``self.randomize`` given the noised counts
              and the original counts.
        """
        # One sigma per entry, drawn uniformly from [sigma_min, sigma_max];
        # ``empty_like`` keeps the shape, dtype and device of the input.
        sigma_ng = torch.empty_like(x_ng).uniform_(self.sigma_min, self.sigma_max)
        x_aug = x_ng + torch.normal(mean=0, std=sigma_ng)

        # ``randomize`` receives the augmented tensor first, the original second.
        x_ng = self.randomize(x_aug, x_ng)
        return {"x_ng": x_ng}