Source code for cellarium.ml.models.probabilistic_pca

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


from typing import Literal

import numpy as np
import pyro
import pyro.distributions as dist
import torch
from pyro.nn.module import PyroParam, _unconstrain
from torch.distributions import constraints

from cellarium.ml.models.model import CellariumModel, PredictMixin
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)


class ProbabilisticPCA(CellariumModel, PredictMixin):
    """
    Probabilistic PCA implemented in Pyro.

    Two flavors of probabilistic PCA are available - marginalized pPCA [1]
    and linear VAE [2].

    **References:**

    1. `Probabilistic Principal Component Analysis (Tipping et al.)
       <https://www.robots.ox.ac.uk/~cvrg/hilary2006/ppca.pdf>`_.
    2. `Understanding Posterior Collapse in Generative Latent Variable Models (Lucas et al.)
       <https://openreview.net/pdf?id=r1xaVLUYuE>`_.

    Args:
        n_obs:
            Number of cells.
        var_names_g:
            The variable names schema for the input data validation.
        n_components:
            Number of principal components.
        ppca_flavor:
            Type of the PPCA model. Has to be one of `marginalized` or `linear_vae`.
            Any value other than `marginalized` is handled by the `linear_vae` code path.
        mean_g:
            Mean gene expression of the input data. If ``None`` then the mean is
            set to a learnable parameter.
        W_init_scale:
            Scale of the random initialization of the `W_kg` parameter.
        sigma_init_scale:
            Initialization value of the `sigma` parameter.
        seed:
            Random seed used to initialize parameters.
    """

    def __init__(
        self,
        n_obs: int,
        var_names_g: np.ndarray,
        n_components: int,
        ppca_flavor: Literal["marginalized", "linear_vae"],
        mean_g: torch.Tensor | None = None,
        W_init_scale: float = 1.0,
        sigma_init_scale: float = 1.0,
        seed: int = 0,
    ):
        super().__init__()

        self.n_obs = n_obs
        self.var_names_g = var_names_g
        n_vars = len(self.var_names_g)
        self.n_vars = n_vars
        self.n_components = n_components
        self.ppca_flavor = ppca_flavor
        self.elbo = pyro.infer.Trace_ELBO()

        # The shape check is skipped for 0-dim tensors (``mean_g.dim()`` is 0),
        # so a scalar mean is accepted as-is.
        if isinstance(mean_g, torch.Tensor) and mean_g.dim():
            assert mean_g.shape == (n_vars,), (
                f"Expected mean_g to have a shape ({n_vars},) but found {mean_g.shape}."
            )
        if mean_g is None:
            # make mean_g a learnable parameter
            self.mean_g = torch.nn.Parameter(torch.empty(n_vars))
        else:
            # a provided mean is stored as a (non-trainable) buffer
            self.register_buffer("mean_g", mean_g)

        self.seed = seed
        # model parameters
        self.W_init_scale = W_init_scale
        self.sigma_init_scale = sigma_init_scale
        self.W_kg = torch.nn.Parameter(torch.empty(n_components, n_vars))
        self.sigma = PyroParam(torch.empty(()), constraint=constraints.positive)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """
        (Re-)initialize the model parameters.

        ``mean_g`` is zeroed when it is a learnable parameter, ``W_kg`` is drawn
        from a zero-mean normal with standard deviation ``W_init_scale`` using a
        generator seeded with ``seed``, and ``sigma`` is set to ``sigma_init_scale``.
        """
        # The generator lives on the device of ``W_kg``; the CPU is used instead
        # when ``W_kg`` is on the ``meta`` device.
        rng_device = self.W_kg.device.type if self.W_kg.device.type != "meta" else "cpu"
        rng = torch.Generator(device=rng_device)
        rng.manual_seed(self.seed)

        if isinstance(self.mean_g, torch.nn.Parameter):
            self.mean_g.data.zero_()
        self.W_kg.data.normal_(0, self.W_init_scale, generator=rng)
        # ``sigma`` is a constrained PyroParam: write the initial value through
        # its unconstrained counterpart.
        self.sigma_unconstrained.data.fill_(
            _unconstrain(torch.as_tensor(self.sigma_init_scale), constraints.positive)
        )

    def forward(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, torch.Tensor | None]:
        """
        Args:
            x_ng:
                Gene counts matrix.
            var_names_g:
                The list of the variable names in the input data.

        Returns:
            A dictionary with the loss value.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

        loss = self.elbo.differentiable_loss(self.model, self.guide, x_ng)
        return {"loss": loss}

    def model(self, x_ng: torch.Tensor) -> None:
        """
        Pyro model.

        For the ``marginalized`` flavor the observations are scored under a
        low-rank multivariate normal with covariance ``W_kg.T @ W_kg + sigma**2 * I``.
        Otherwise a standard normal latent ``z`` is sampled and the observations
        are scored under ``Normal(mean_g + z @ W_kg, sigma)``.

        Args:
            x_ng:
                Gene counts matrix.
        """
        with pyro.plate("cells", size=self.n_obs, subsample_size=x_ng.shape[0]):
            if self.ppca_flavor == "marginalized":
                pyro.sample(
                    "counts",
                    dist.LowRankMultivariateNormal(
                        loc=self.mean_g,
                        cov_factor=self.W_kg.T,
                        cov_diag=self.sigma**2 * x_ng.new_ones(self.n_vars),  # type: ignore[operator]
                    ),
                    obs=x_ng,
                )
            else:
                z_nk = pyro.sample(
                    "z",
                    dist.Normal(x_ng.new_zeros(self.n_components), 1).to_event(1),
                )
                pyro.sample(
                    "counts",
                    dist.Normal(self.mean_g + z_nk @ self.W_kg, self.sigma).to_event(1),  # type: ignore[arg-type]
                    obs=x_ng,
                )

    def guide(self, x_ng: torch.Tensor) -> None:
        """
        Pyro guide.

        The ``marginalized`` flavor has no latent sample sites, so the guide is
        empty. Otherwise ``z`` is sampled from a diagonal normal whose mean is
        the same linear embedding as computed in :meth:`predict`.

        Args:
            x_ng:
                Gene counts matrix.
        """
        if self.ppca_flavor == "marginalized":
            return

        with pyro.plate("cells", size=self.n_obs, subsample_size=x_ng.shape[0]):
            V_gk = torch.linalg.solve(self.M_kk, self.W_kg).T
            D_k = self.sigma / torch.sqrt(torch.diag(self.M_kk))  # type: ignore[arg-type]
            pyro.sample("z", dist.Normal((x_ng - self.mean_g) @ V_gk, D_k).to_event(1))

    def predict(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, np.ndarray | torch.Tensor]:
        """
        Centering and embedding of the input data ``x_ng`` into the principal component space.

        .. note::
           Used for inference only. This method does not itself switch off
           gradient tracking. NOTE(review): presumably the calling prediction
           loop does - confirm.

        Args:
            x_ng:
                Gene counts matrix.
            var_names_g:
                The list of the variable names in the input data.

        Returns:
            A dictionary with the following keys:

            - ``z_nk``: Embedding of the input data into the principal component space.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g)

        V_gk = torch.linalg.solve(self.M_kk, self.W_kg).T
        z_nk = (x_ng - self.mean_g) @ V_gk
        return {"z_nk": z_nk}

    @property
    def M_kk(self) -> torch.Tensor:
        r"""
        The ``(n_components, n_components)`` matrix ``W_kg @ W_kg.T + sigma**2 * I``.
        """
        return self.W_kg @ self.W_kg.T + self.sigma**2 * torch.eye(self.n_components, device=self.sigma.device)  # type: ignore[attr-defined, operator]

    @property
    @torch.inference_mode()
    def L_k(self) -> torch.Tensor:
        r"""
        Vector with elements given by the PC eigenvalues.

        Computed as the squared singular values of ``W_kg.T`` plus ``sigma**2``.

        .. note::
           Gradients are disabled, used for inference only.
        """
        S_k = torch.linalg.svdvals(self.W_kg.T)
        return S_k**2 + self.sigma**2  # type: ignore[operator]

    @property
    @torch.inference_mode()
    def U_gk(self) -> torch.Tensor:
        r"""
        Principal components corresponding to eigenvalues ``L_k``.

        .. note::
           Gradients are disabled, used for inference only.
        """
        return torch.linalg.svd(self.W_kg.T, full_matrices=False).U

    @property
    @torch.inference_mode()
    def W_variance(self) -> float:
        r"""
        Trace of ``W_kg.T @ W_kg``, i.e. the sum of the squared entries of ``W_kg``.

        .. note::
           Gradients are disabled, used for inference only.
        """
        # trace(W.T @ W) equals the sum of squared entries of W. Summing the
        # squares directly avoids materializing an (n_vars, n_vars) matrix.
        return torch.sum(self.W_kg**2).item()

    @property
    @torch.inference_mode()
    def sigma_variance(self) -> float:
        r"""
        ``n_vars * sigma**2``, the trace of the isotropic noise covariance ``sigma**2 * I``.

        .. note::
           Gradients are disabled, used for inference only.
        """
        return (self.n_vars * self.sigma**2).item()  # type: ignore[operator]