Source code for cellarium.ml.models.logistic_regression

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


import lightning.pytorch as pl
import numpy as np
import pyro
import pyro.distributions as dist
import torch

from cellarium.ml.models.model import CellariumModel, PredictMixin
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)


class LogisticRegression(CellariumModel, PredictMixin):
    """
    Logistic regression model with a Laplace prior on the weights.

    The weight matrix ``W_gc`` is fit as a point estimate (a Delta guide) under a
    Laplace prior, while the bias ``b_c`` is an ordinary learnable parameter with no prior.

    Args:
        n_obs:
            Number of observations. Used as the plate size in :meth:`model`, so that
            minibatch log-likelihoods are scaled relative to the full dataset.
        var_names_g:
            The variable names schema for the input data validation.
        y_categories:
            The categories for the target data.
        W_prior_scale:
            The scale of the Laplace prior for the weights.
        W_init_scale:
            Initialization scale for the ``W_gc`` parameter.
        seed:
            Random seed used to initialize parameters.
        log_metrics:
            Whether to log the histogram of the ``W_gc`` parameter.
    """

    def __init__(
        self,
        n_obs: int,
        var_names_g: np.ndarray,
        y_categories: np.ndarray,
        W_prior_scale: float = 1.0,
        W_init_scale: float = 1.0,
        seed: int = 0,
        log_metrics: bool = True,
    ) -> None:
        super().__init__()

        # Data schema, checked against incoming batches in forward/predict.
        self.n_obs = n_obs
        self.var_names_g = var_names_g
        self.y_categories = y_categories
        self.n_vars = len(var_names_g)
        self.n_categories = len(y_categories)
        self.seed = seed

        # Hyperparameters. The prior scale is kept both as a plain float (read by
        # reset_parameters) and as a registered buffer, so that the tensor used by the
        # prior is moved together with the module (e.g. by ``.to(device)``).
        self._W_prior_scale = W_prior_scale
        self.W_init_scale = W_init_scale
        self.W_prior_scale: torch.Tensor
        self.register_buffer("W_prior_scale", torch.empty(()))

        # Learnable parameters: allocated uninitialized here, filled by reset_parameters.
        self.W_gc = torch.nn.Parameter(torch.empty(self.n_vars, self.n_categories))
        self.b_c = torch.nn.Parameter(torch.empty(self.n_categories))
        self.reset_parameters()

        # Training objective and logging switch.
        self.elbo = pyro.infer.Trace_ELBO()
        self.log_metrics = log_metrics

    def reset_parameters(self) -> None:
        """
        (Re)initialize the prior-scale buffer and the learnable parameters.

        ``W_prior_scale`` is set to the value given at construction time. ``W_gc`` is drawn
        from ``Normal(0, W_init_scale)`` using a generator seeded with ``seed``. ``b_c`` is
        set to zero.
        """
        param_device = self.W_gc.device.type
        # Do not build the generator on the "meta" device; seed a CPU generator instead.
        generator_device = "cpu" if param_device == "meta" else param_device
        generator = torch.Generator(device=generator_device).manual_seed(self.seed)

        self.W_prior_scale.fill_(self._W_prior_scale)
        self.W_gc.data.normal_(0, self.W_init_scale, generator=generator)
        self.b_c.data.zero_()

    def forward(
        self, x_ng: torch.Tensor, var_names_g: np.ndarray, y_n: torch.Tensor, y_categories: np.ndarray
    ) -> dict[str, torch.Tensor | None]:
        """
        Compute the training loss for a minibatch.

        The batch is first validated against the schema fixed at construction time. The
        ``assert_*`` helpers are presumably expected to raise on a mismatch.

        Args:
            x_ng: The input data.
            var_names_g: The variable names for the input data.
            y_n: The target data.
            y_categories: The categories for the input target data.

        Returns:
            A dictionary whose ``"loss"`` entry is the differentiable ELBO loss computed from
            :meth:`model` and :meth:`guide`.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g)
        assert_arrays_equal("y_categories", y_categories, "self.y_categories", self.y_categories)

        return {"loss": self.elbo.differentiable_loss(self.model, self.guide, x_ng, y_n)}

    def model(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
        """
        Pyro model: a Laplace prior on ``W`` and a categorical likelihood for ``y``.

        The logits are ``x_ng @ W + b_c``. ``y_n`` is observed inside a plate of size
        ``n_obs`` that is subsampled to the current batch size.
        """
        W_prior = dist.Laplace(0, self.W_prior_scale).expand((self.n_vars, self.n_categories)).to_event(2)
        W_gc = pyro.sample("W", W_prior)

        logits_nc = x_ng @ W_gc + self.b_c
        n_batch = x_ng.shape[0]
        with pyro.plate("batch", size=self.n_obs, subsample_size=n_batch):
            pyro.sample("y", dist.Categorical(logits=logits_nc), obs=y_n)

    def guide(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
        """Pyro guide: a point mass (Delta) at ``W_gc``, so ``W`` is fit as a point estimate."""
        W_point_mass = dist.Delta(self.W_gc).to_event(2)
        pyro.sample("W", W_point_mass)

    def predict(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, np.ndarray | torch.Tensor]:
        """
        Predict the target logits.

        Args:
            x_ng: The input data.
            var_names_g: The variable names for the input data.

        Returns:
            A dictionary whose ``"y_logits_nc"`` entry holds ``x_ng @ W_gc + b_c``.
        """
        assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
        assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g)

        return {"y_logits_nc": x_ng @ self.W_gc + self.b_c}

    def on_train_batch_end(self, trainer: pl.Trainer) -> None:
        """
        Log a histogram of ``W_gc`` to every TensorBoard logger attached to ``trainer``.

        Logging happens only on global rank 0, only when ``log_metrics`` is set, and only
        on steps where ``global_step + 1`` is a multiple of ``trainer.log_every_n_steps``.
        """
        should_log = (
            trainer.global_rank == 0
            and self.log_metrics
            and (trainer.global_step + 1) % trainer.log_every_n_steps == 0  # type: ignore[attr-defined]
        )
        if not should_log:
            return

        tensorboard_loggers = (lg for lg in trainer.loggers if isinstance(lg, pl.loggers.TensorBoardLogger))
        for tb_logger in tensorboard_loggers:
            tb_logger.experiment.add_histogram("W_gc", self.W_gc, global_step=trainer.global_step)