# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

import lightning.pytorch as pl
import numpy as np
import pyro
import pyro.distributions as dist
import torch

from cellarium.ml.models.model import CellariumModel, PredictMixin, ValidateMixin
from cellarium.ml.utilities.testing import (
assert_arrays_equal,
assert_columns_and_array_lengths_equal,
)


class LogisticRegression(CellariumModel, PredictMixin, ValidateMixin):
"""
Logistic regression model.
Args:
n_obs:
Number of observations.
var_names_g:
The variable names schema for the input data validation.
y_categories:
The categories for the target data.
W_prior_scale:
The scale of the Laplace prior for the weights.
W_init_scale:
Initialization scale for the ``W_gc`` parameter.
seed:
Random seed used to initialize parameters.
log_metrics:
Whether to log the histogram of the ``W_gc`` parameter.
"""
def __init__(
self,
n_obs: int,
var_names_g: np.ndarray,
y_categories: np.ndarray,
W_prior_scale: float = 1.0,
W_init_scale: float = 1.0,
seed: int = 0,
log_metrics: bool = True,
) -> None:
        super().__init__()

        # data
        self.n_obs = n_obs
        self.var_names_g = var_names_g
        self.n_vars = len(var_names_g)
        self.y_categories = y_categories
        self.n_categories = len(y_categories)
        self.seed = seed

        # parameters
        self._W_prior_scale = W_prior_scale
        self.W_init_scale = W_init_scale
        self.W_prior_scale: torch.Tensor
        self.register_buffer("W_prior_scale", torch.empty(()))
        self.W_gc = torch.nn.Parameter(torch.empty(self.n_vars, self.n_categories))
        self.b_c = torch.nn.Parameter(torch.empty(self.n_categories))
        self.reset_parameters()

        # loss
        self.elbo = pyro.infer.Trace_ELBO()
        self.log_metrics = log_metrics

    def reset_parameters(self) -> None:
        # parameters may live on the "meta" device before they are materialized;
        # seed the generator on CPU in that case so initialization stays reproducible
        rng_device = self.W_gc.device.type if self.W_gc.device.type != "meta" else "cpu"
        rng = torch.Generator(device=rng_device)
        rng.manual_seed(self.seed)
self.W_prior_scale.fill_(self._W_prior_scale)
self.W_gc.data.normal_(0, self.W_init_scale, generator=rng)
self.b_c.data.zero_()
def forward(
self, x_ng: torch.Tensor, var_names_g: np.ndarray, y_n: torch.Tensor, y_categories: np.ndarray
) -> dict[str, torch.Tensor | None]:
"""
Args:
x_ng:
The input data.
var_names_g:
The variable names for the input data.
y_n:
The target data.
y_categories:
The categories for the input target data.
Returns:
A dictionary with the loss value.
"""
assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g)
assert_arrays_equal("y_categories", y_categories, "self.y_categories", self.y_categories)
loss = self.elbo.differentiable_loss(self.model, self.guide, x_ng, y_n)
return {"loss": loss}
def model(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
W_gc = pyro.sample(
"W",
dist.Laplace(0, self.W_prior_scale).expand([self.n_vars, self.n_categories]).to_event(2),
)
with pyro.plate("batch", size=self.n_obs, subsample_size=x_ng.shape[0]):
logits_nc = x_ng @ W_gc + self.b_c
pyro.sample("y", dist.Categorical(logits=logits_nc), obs=y_n)
def guide(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
pyro.sample("W", dist.Delta(self.W_gc).to_event(2))
def predict(self, x_ng: torch.Tensor, var_names_g: np.ndarray) -> dict[str, np.ndarray | torch.Tensor]:
"""
Predict the target logits.
Args:
x_ng:
The input data.
var_names_g:
The variable names for the input data.
Returns:
A dictionary with the target logits.
"""
assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g)
assert_arrays_equal("var_names_g", var_names_g, "self.var_names_g", self.var_names_g)
logits_nc = x_ng @ self.W_gc + self.b_c
return {"y_logits_nc": logits_nc}
def on_train_batch_end(self, trainer: pl.Trainer) -> None:
if trainer.global_rank != 0:
return
if not self.log_metrics:
return
if (trainer.global_step + 1) % trainer.log_every_n_steps != 0: # type: ignore[attr-defined]
return
for logger in trainer.loggers:
if isinstance(logger, pl.loggers.TensorBoardLogger):
logger.experiment.add_histogram(
"W_gc",
self.W_gc,
global_step=trainer.global_step,
)
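

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). In
# practice this class is orchestrated by the Cellarium ML training pipeline;
# here the model is driven by hand on synthetic data to show the shapes the
# forward and predict methods expect. All names, shapes, and category labels
# below are made up for the example.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    n_obs, n_vars, n_categories = 100, 5, 3
    var_names_g = np.array([f"gene_{i}" for i in range(n_vars)])
    y_categories = np.array([f"cell_type_{i}" for i in range(n_categories)])

    model = LogisticRegression(
        n_obs=n_obs,
        var_names_g=var_names_g,
        y_categories=y_categories,
        W_prior_scale=1.0,
        W_init_scale=1.0,
    )

    x_ng = torch.randn(n_obs, n_vars)
    y_n = torch.randint(0, n_categories, (n_obs,))

    # a single hand-rolled optimization step on the differentiable ELBO loss
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    loss = model(x_ng, var_names_g, y_n, y_categories)["loss"]
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # predicted class logits for a batch of observations
    logits_nc = model.predict(x_ng, var_names_g)["y_logits_nc"]
    print(logits_nc.shape)  # torch.Size([100, 3])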