# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import torch
from torch import nn
from torch.distributions import Binomial
from .randomize import Randomize
[docs]
class BinomialResample(nn.Module):
    """
    Binomial resampling of gene counts.

    For each count, the parameter to the binomial distribution
    is independently and uniformly sampled according to the
    bounding parameters, yielding the parameter matrix p_ng.

    .. math::

        y_{ng} = Binomial(n=x_{ng}, p=p_{ng})

    Args:
        p_binom_min:
            Lower bound on binomial distribution parameter.
        p_binom_max:
            Upper bound on binomial distribution parameter.
        p_apply:
            Probability of applying transform to each sample.

    Raises:
        ValueError:
            If the bounds do not satisfy
            ``0 <= p_binom_min <= p_binom_max <= 1``.
    """

    def __init__(self, p_binom_min: float, p_binom_max: float, p_apply: float):
        super().__init__()

        # Fail fast on an unusable configuration instead of surfacing an
        # error only on the first forward pass. NaN bounds also fail this
        # chained comparison.
        if not 0.0 <= p_binom_min <= p_binom_max <= 1.0:
            raise ValueError(
                "Expected 0 <= p_binom_min <= p_binom_max <= 1, "
                f"got p_binom_min={p_binom_min} and p_binom_max={p_binom_max}."
            )

        self.p_binom_min = p_binom_min
        self.p_binom_max = p_binom_max
        # Randomize is defined in the sibling ``randomize`` module; it is
        # handed the augmented and the original counts in ``forward``.
        self.randomize = Randomize(p_apply)

    def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x_ng:
                Gene counts. May be stored in a floating-point or an
                integer dtype.

        Returns:
            Dictionary whose ``"x_ng"`` entry holds the binomially
            resampled gene counts.
        """
        # ``Tensor.uniform_`` and ``torch.binomial`` only accept
        # floating-point tensors, so integer counts are sampled in the
        # default floating dtype. Floating-point inputs are used as-is.
        if x_ng.is_floating_point():
            counts_ng = x_ng
        else:
            counts_ng = x_ng.to(torch.get_default_dtype())

        # One independent success probability per matrix entry.
        p_binom_ng = torch.empty_like(counts_ng).uniform_(self.p_binom_min, self.p_binom_max)
        x_aug = Binomial(total_count=counts_ng, probs=p_binom_ng).sample()

        # Hand Randomize two tensors of the same dtype as the input.
        if x_aug.dtype != x_ng.dtype:
            x_aug = x_aug.to(x_ng.dtype)

        x_ng = self.randomize(x_aug, x_ng)
        return {"x_ng": x_ng}