# Source code for cellarium.ml.distributions.negative_binomial

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from numbers import Number

import torch
from pyro.distributions import TorchDistribution, constraints
from torch.distributions.utils import broadcast_all, lazy_property


class NegativeBinomial(TorchDistribution):
    """Negative binomial distribution.

    Parameterized by its mean ``mu`` and inverse dispersion ``theta``:
    ``Var = mu + mu**2 / theta``. The degenerate case ``theta == 0`` is
    treated as a point mass at zero (see ``log_prob`` and ``variance``).

    Args:
        mu: Mean of the distribution. Must be >= 0.
        theta: Inverse dispersion. Must be >= 0.
        validate_args: Whether to validate arguments (passed through to the
            base distribution).
    """

    # Above this theta, the exact lgamma-based term in log_prob is replaced by
    # a Stirling-style approximation, which is numerically stabler for large theta.
    THETA_THRESHOLD_STIRLING_SWITCH = 200

    arg_constraints = {
        "mu": constraints.greater_than_eq(0),
        "theta": constraints.greater_than_eq(0),
    }
    support = constraints.nonnegative_integer

    def __init__(self, mu: torch.Tensor, theta: torch.Tensor, validate_args: bool | None = None) -> None:
        self.mu, self.theta = broadcast_all(mu, theta)
        # Scalar inputs yield an empty batch shape, mirroring torch.distributions.
        if isinstance(mu, Number) and isinstance(theta, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.mu.size()
        super().__init__(batch_shape, validate_args=validate_args)

    @property
    def mean(self) -> torch.Tensor:
        return self.mu

    @property
    def variance(self) -> torch.Tensor:
        # Var = mu + mu**2 / theta; masked to 0 where theta == 0 (point mass at 0),
        # which also avoids the division by zero producing inf/nan.
        return (self.mu + (self.mu**2) / self.theta).masked_fill(self.theta == 0, 0)

    @lazy_property
    def _gamma(self) -> torch.distributions.Gamma:
        # Gamma mixing distribution for the Gamma-Poisson construction of the NB.
        # Note we avoid validating because self.theta can be zero.
        return torch.distributions.Gamma(
            concentration=self.theta,
            rate=(self.theta / self.mu).masked_fill(self.theta == 0, 1.0),
            validate_args=False,
        )

    def sample(self, sample_shape=torch.Size()) -> torch.Tensor:
        """Draw samples via the Gamma-Poisson mixture representation."""
        with torch.no_grad():
            rate = self._gamma.sample(sample_shape=sample_shape)
            return torch.poisson(rate)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Return the log-probability mass of ``value``.

        Equivalent to the standard NB log-pmf
        ``lgamma(v + theta) - lgamma(theta) - lgamma(v + 1)
        + theta * log(theta / (theta + mu)) + v * log(mu / (theta + mu))``,
        rearranged for numerical stability.
        """
        if self._validate_args:
            self._validate_sample(value)

        # Original implementation from scVI:
        #
        #   log_theta_mu_eps = torch.log(self.theta + self.mu + self.eps)
        #   return (
        #       self.theta * (torch.log(self.theta + self.eps) - log_theta_mu_eps)
        #       + value * (torch.log(self.mu + self.eps) - log_theta_mu_eps)
        #       + torch.lgamma(value + self.theta)
        #       - torch.lgamma(self.theta)
        #       - torch.lgamma(value + 1)
        #   )

        # delta = lgamma(v + theta) - lgamma(theta) - v * log(theta), computed
        # exactly for small theta and via a Stirling approximation for large
        # theta, where the difference of two huge lgamma values loses precision.
        delta = torch.where(
            self.theta > self.THETA_THRESHOLD_STIRLING_SWITCH,
            (value + self.theta - 0.5) * torch.log1p(value / self.theta) - value,
            (value + self.theta).lgamma() - self.theta.lgamma() - torch.xlogy(value, self.theta),
        )
        # The case self.theta == 0 and value == 0 has probability 1.
        # The case self.theta == 0 and value != 0 has probability 0.
        return (
            (delta - (value + self.theta) * torch.log1p(self.mu / self.theta)).masked_fill(self.theta == 0, 0)
            - (value + 1).lgamma()
            + torch.xlogy(value, self.mu)
        )