# Source code for cellarium.ml.models.hvg_seurat_v3

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import warnings
from typing import Literal

import lightning.pytorch as pl
import numpy as np
import pandas as pd
import skmisc.loess
import torch
import torch.distributed as dist
from lightning.pytorch.strategies import DDPStrategy

from cellarium.ml.core.datamodule import CellariumAnnDataDataModule
from cellarium.ml.models.model import CellariumModel
from cellarium.ml.utilities.testing import (
    assert_arrays_equal,
    assert_columns_and_array_lengths_equal,
)


def _fit_loess_with_jitter(
    x: np.ndarray,
    y: np.ndarray,
    span: float,
    max_jitter: float = 1e-6,
    initial_jitter: float = 1e-18,
    seed: int = 0,
) -> np.ndarray:
    """Fit a degree-2 LOESS of ``y`` on ``x``, retrying with growing jitter on failure.

    ``skmisc.loess`` raises ``ValueError`` for some inputs (e.g. near-duplicate
    ``x`` values). Each retry perturbs ``x`` with uniform noise whose half-width
    starts at ``initial_jitter`` and grows tenfold per attempt, up to ``max_jitter``.

    Args:
        x: Predictor values, shape ``(n,)``.
        y: Response values, shape ``(n,)``.
        span: LOESS span (fraction of data used per local fit).
        max_jitter: Largest jitter half-width attempted before giving up.
        initial_jitter: Jitter half-width used on the first retry.
        seed: Seed for the jitter RNG, for reproducibility.

    Returns:
        Fitted values from the first successful fit, shape ``(n,)``.

    Raises:
        ValueError: If every attempt up to ``max_jitter`` fails.
    """
    rng = np.random.default_rng(seed)
    half_width = 0.0

    while half_width <= max_jitter:
        if half_width == 0:
            # First attempt: use the data unperturbed.
            x_attempt = x
        else:
            x_attempt = x + rng.uniform(-half_width, half_width, size=x.shape[0])

        try:
            fit = skmisc.loess.loess(x_attempt, y, span=span, degree=2)
            fit.fit()
            return fit.outputs.fitted_values
        except ValueError:
            half_width = initial_jitter if half_width == 0 else half_width * 10

    raise ValueError(f"LOESS fit failed after retrying with jitter up to {max_jitter}.")


class HVGSeuratV3(CellariumModel):
    """
    Compute highly variable genes using the Seurat v3 method in two Lightning epochs.

    **Epoch 0** — streams data to accumulate per-batch mean and variance.

    Between epochs — fits a LOESS model of ``log10(var) ~ log10(mean)`` per batch to
    estimate a regularized standard deviation and per-cell clip value.

    **Epoch 1** — streams data again, clips counts per cell at the batch-level
    ``clip_val``, and accumulates clipped sums.

    After epoch 1 (``on_train_epoch_end``) — computes normalized variance per batch,
    ranks genes, and writes ``self.hvg_dfs`` (and optionally a CSV file at
    ``output_path``).

    The ``flavor`` argument controls how genes are ranked across batches when
    ``n_batch > 1``, matching the two multi-batch modes in Scanpy:

    * ``"seurat_v3"`` *(default)* — sorts by ``highly_variable_rank`` ascending first,
      then by ``highly_variable_nbatches`` descending as tiebreaker. Genes with the
      lowest median rank across batches are preferred. This matches Scanpy's
      ``flavor='seurat_v3'``.
    * ``"seurat_v3_paper"`` — sorts by ``highly_variable_nbatches`` descending first,
      then by ``highly_variable_rank`` ascending as tiebreaker. Genes that are highly
      variable in more batches are preferred, regardless of their exact rank. This
      matches Scanpy's ``flavor='seurat_v3_paper'``.

    Usage::

        model = HVGSeuratV3(var_names_g=gene_names, n_top_genes=2000, n_batch=4,
                            flavor="seurat_v3_paper", span=0.3)
        trainer = pl.Trainer(max_epochs=2)
        trainer.fit(module, datamodule)
        df = model.hvg_dfs[2000]  # pandas DataFrame, Scanpy-compatible columns

    Args:
        var_names_g: Array of gene names, length ``n_genes``.
        n_top_genes: Number of highly variable genes to select. Can be a list of ints
            to produce multiple gene sets in a single training run.
        n_batch: Number of batches (use 1 when no batch information is given).
        flavor: Multi-batch gene ranking strategy. ``"seurat_v3"`` (default)
            prioritises median rank; ``"seurat_v3_paper"`` prioritises batch
            consistency.
        use_batch_key: Whether to expect a ``batch_index_n`` batch key in the
            dataloader output. If False, the model will ignore batch keys and treat
            all data as a single batch.
        span: LOESS span (fraction of data used per local fit). Default 0.3.
        output_path: If given, the result DataFrame is written to this CSV filepath
            after training (must end with ``.csv``). If ``None``, no file is written,
            and obtaining the output later requires manually calling
            ``_compute_hvg_df()``.
        batch_n_cell_minimum: Minimum number of cells required for a batch to be
            considered valid.
        n_batch_minimum: Minimum number of batches in which the gene is in
            ``n_top_genes`` highly variable for a gene to make the final list.
    """

    _VALID_FLAVORS = frozenset({"seurat_v3", "seurat_v3_paper"})

    def __init__(
        self,
        var_names_g: np.ndarray,
        n_top_genes: int | list[int],
        n_batch: int = 1,
        flavor: Literal["seurat_v3", "seurat_v3_paper"] = "seurat_v3",
        use_batch_key: bool = False,
        span: float = 0.3,
        output_path: str | None = "hvg_seurat_v3_output.csv",
        batch_n_cell_minimum: int = 2,
        n_batch_minimum: int = 1,
    ) -> None:
        super().__init__()
        if flavor not in self._VALID_FLAVORS:
            raise ValueError(f"flavor must be one of {sorted(self._VALID_FLAVORS)}, got {flavor!r}")
        self.var_names_g = var_names_g
        n_vars = len(var_names_g)
        self.n_vars = n_vars

        # Normalize to a sorted, deduplicated list so each n yields one output set.
        if isinstance(n_top_genes, int):
            self.n_top_genes_list = [n_top_genes]
        else:
            self.n_top_genes_list = sorted(set(n_top_genes))

        self.n_batch = n_batch
        self.flavor = flavor
        self.use_batch_key = use_batch_key
        if use_batch_key and n_batch < 2:
            raise ValueError(
                "n_batch must be at least 2 when use_batch_key is True. This error may also be triggered "
                "if your dataloader is not providing the expected `batch_index_n` key: check your dataloader "
                "batch_keys and ensure `batch_index_n` is included when use_batch_key=True --"
                "data:"
                " ..."
                " batch_keys:"
                " ..."
                " batch_index_n:"
                " attr: obs"
                " key: my_categorical_batch_column"
                " convert_fn: cellarium.ml.utilities.data.categories_to_codes"
            )
        self.span = span
        if (output_path is not None) and (not output_path.endswith(".csv")):
            raise ValueError("output_path must end with .csv")
        self.output_path = output_path
        self.batch_n_cell_minimum = batch_n_cell_minimum
        if n_batch_minimum > n_batch:
            raise ValueError(f"n_batch_minimum ({n_batch_minimum}) cannot exceed n_batch ({n_batch}).")
        self.n_batch_minimum = n_batch_minimum

        # Cached by on_train_epoch_start so forward() can branch per epoch.
        self._current_epoch: int = 0

        # Epoch-0 buffers: shape (n_batch, n_vars)
        self.x_sums_bg: torch.Tensor
        self.x_squared_sums_bg: torch.Tensor
        self.x_size_b: torch.Tensor
        self.register_buffer("x_sums_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))
        self.register_buffer("x_squared_sums_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))
        self.register_buffer("x_size_b", torch.zeros(n_batch))

        # Set between epochs by on_train_epoch_end after epoch 0: shape (n_batch, n_vars)
        self.clip_val_bg: torch.Tensor
        self.reg_std_bg: torch.Tensor
        self.register_buffer("clip_val_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))
        self.register_buffer("reg_std_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))

        # Epoch-1 buffers: shape (n_batch, n_vars)
        self.counts_sum_bg: torch.Tensor
        self.sq_counts_sum_bg: torch.Tensor
        self.register_buffer("counts_sum_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))
        self.register_buffer("sq_counts_sum_bg", torch.zeros(n_batch, n_vars, dtype=torch.float64))

        # Dummy parameter so Lightning treats this as a trainable module
        self._dummy_param = torch.nn.Parameter(torch.empty(()))
        self._dummy_param.data.zero_()

        # Populated on rank 0 after epoch 1: {n_top_genes: DataFrame}.
        self.hvg_dfs: dict[int, pd.DataFrame] | None = None

    def reset_parameters(self) -> None:
        """Zero all accumulation buffers and the dummy parameter."""
        self.x_sums_bg.zero_()
        self.x_squared_sums_bg.zero_()
        self.x_size_b.zero_()
        self.clip_val_bg.zero_()
        self.reg_std_bg.zero_()
        self.counts_sum_bg.zero_()
        self.sq_counts_sum_bg.zero_()
        self._dummy_param.data.zero_()
[docs] def on_train_epoch_start(self, trainer: pl.Trainer) -> None: """Cache the current epoch number so forward() can branch on it.""" self._current_epoch = trainer.current_epoch
def forward( self, x_ng: torch.Tensor, var_names_g: np.ndarray, batch_index_n: torch.Tensor | None = None, ) -> dict[str, torch.Tensor | None]: assert_columns_and_array_lengths_equal("x_ng", x_ng, "var_names_g", var_names_g) assert_arrays_equal("var_names_g", var_names_g, "var_names_g", self.var_names_g) # help users avoid unintentional errors in configuring the use of batch_key if self.use_batch_key and (batch_index_n is None): raise ValueError( "batch_index_n is required when use_batch_key is True: add `batch_index_n` to your " "dataloader batch_keys." ) if (not self.use_batch_key) and (batch_index_n is not None): raise ValueError( "batch_index_n was given but use_batch_key is False: either set use_batch_key=True or remove " "batch_index_n from your dataloader batch_keys." ) if batch_index_n is None: batch_index_n = torch.zeros(x_ng.shape[0], dtype=torch.long, device=x_ng.device) else: batch_index_n = batch_index_n.long() # needed for scatter_add_ if self._current_epoch == 0: self._accumulate_epoch0(x_ng, batch_index_n) elif self._current_epoch == 1: self._accumulate_epoch1(x_ng, batch_index_n) else: raise RuntimeError(f"HVGSeuratV3 expects max_epochs=2, but got epoch {self._current_epoch}.") return {} def _accumulate_epoch0(self, x_ng: torch.Tensor, batch_idx_n: torch.Tensor) -> None: n_cells = x_ng.shape[0] x64_ng = x_ng.double() idx_exp_ng = batch_idx_n.unsqueeze(1).expand(n_cells, self.n_vars) sums_contrib_bg = torch.zeros(self.n_batch, self.n_vars, dtype=torch.float64, device=x_ng.device) sq_sums_contrib_bg = torch.zeros(self.n_batch, self.n_vars, dtype=torch.float64, device=x_ng.device) sums_contrib_bg.scatter_add_(0, idx_exp_ng, x64_ng) sq_sums_contrib_bg.scatter_add_(0, idx_exp_ng, x64_ng**2) self.x_sums_bg = self.x_sums_bg + sums_contrib_bg self.x_squared_sums_bg = self.x_squared_sums_bg + sq_sums_contrib_bg self.x_size_b = self.x_size_b + torch.bincount(batch_idx_n, minlength=self.n_batch) def _accumulate_epoch1(self, x_ng: torch.Tensor, batch_idx_n: 
torch.Tensor) -> None: n_cells = x_ng.shape[0] x64_ng = x_ng.double() # Per-cell clip value: shape (n_cells, n_vars) per_cell_clip_ng = self.clip_val_bg[batch_idx_n] # float64 x_clipped_ng = torch.minimum(x64_ng, per_cell_clip_ng) idx_exp_ng = batch_idx_n.unsqueeze(1).expand(n_cells, self.n_vars) sums_contrib_bg = torch.zeros(self.n_batch, self.n_vars, dtype=torch.float64, device=x_ng.device) sq_sums_contrib_bg = torch.zeros(self.n_batch, self.n_vars, dtype=torch.float64, device=x_ng.device) sums_contrib_bg.scatter_add_(0, idx_exp_ng, x_clipped_ng) sq_sums_contrib_bg.scatter_add_(0, idx_exp_ng, x_clipped_ng**2) self.counts_sum_bg = self.counts_sum_bg + sums_contrib_bg self.sq_counts_sum_bg = self.sq_counts_sum_bg + sq_sums_contrib_bg
[docs] def on_train_start(self, trainer: pl.Trainer) -> None: """Validation: if in a distributed setting, use DDP with broadcast_buffers=False.""" if trainer.world_size > 1: if not isinstance(trainer.strategy, DDPStrategy): raise ValueError("HVGSeuratV3 requires the DDP strategy.") if trainer.strategy._ddp_kwargs.get("broadcast_buffers") is not False: raise ValueError("HVGSeuratV3 requires broadcast_buffers=False.")
[docs] def on_train_epoch_end(self, trainer: pl.Trainer) -> None: """Implement the two-step Seurat v3 HVG method using hooks at the end of first and second epochs. After epoch 0, reduce buffers and fit LOESS to set clip_val_bg and reg_std_bg. After epoch 1, reduce buffers and compute hvg_df. """ if trainer.current_epoch == 0: self._finish_epoch0(trainer) elif trainer.current_epoch == 1: self._finish_epoch1(trainer)
    def _finish_epoch0(self, trainer: pl.Trainer) -> None:
        """Aggregate epoch-0 statistics and derive per-batch clip values (rank 0)."""
        # 1. Reduce epoch-0 buffers to rank 0
        if trainer.world_size > 1:
            dist.reduce(self.x_sums_bg, dst=0, op=dist.ReduceOp.SUM)
            dist.reduce(self.x_squared_sums_bg, dst=0, op=dist.ReduceOp.SUM)
            dist.reduce(self.x_size_b, dst=0, op=dist.ReduceOp.SUM)
        # 2. Rank 0: compute mean/var per batch; fit LOESS; set clip_val_bg and reg_std_bg
        if trainer.global_rank == 0:
            self._compute_clip_val()
        # 3. Broadcast clip_val_bg to all ranks so epoch-1 can use it
        if trainer.world_size > 1:
            dist.broadcast(self.clip_val_bg, src=0)
            dist.broadcast(self.reg_std_bg, src=0)

    def _compute_clip_val(self) -> None:
        """Fit LOESS of log10(var) on log10(mean) per batch; fill reg_std_bg and clip_val_bg."""
        # Default to +inf so any batch skipped below (N < batch_n_cell_minimum) acts as a no-op clip.
        self.clip_val_bg.fill_(float("inf"))
        n_vars = self.n_vars
        for b in range(self.n_batch):
            N = self.x_size_b[b].item()
            if N < self.batch_n_cell_minimum:
                continue
            sums = self.x_sums_bg[b].cpu().numpy()
            sq_sums = self.x_squared_sums_bg[b].cpu().numpy()
            mean_g = sums / N
            # Unbiased (sample) variance with Bessel's correction, matching Scanpy's correction=1
            var_g = (sq_sums - sums**2 / N) / (N - 1)
            not_const = var_g > 0
            # Constant genes keep estimated_var == 0 (i.e. reg_std == 1 after 10**0).
            estimated_var = np.zeros(n_vars, dtype=np.float64)
            if not_const.any():
                x = np.log10(mean_g[not_const])
                y = np.log10(var_g[not_const])
                estimated_var[not_const] = _fit_loess_with_jitter(x, y, span=self.span)
            reg_std = np.sqrt(10**estimated_var)  # shape (n_vars,)
            # Seurat v3 clip value: mean + sqrt(N) * regularized std.
            clip_val = reg_std * np.sqrt(N) + mean_g
            self.reg_std_bg[b] = torch.tensor(reg_std)
            self.clip_val_bg[b] = torch.tensor(clip_val)

    def _finish_epoch1(self, trainer: pl.Trainer) -> None:
        """Aggregate epoch-1 statistics and build the HVG DataFrame(s) on rank 0."""
        # 1. Reduce epoch-1 buffers to rank 0
        if trainer.world_size > 1:
            dist.reduce(self.counts_sum_bg, dst=0, op=dist.ReduceOp.SUM)
            dist.reduce(self.sq_counts_sum_bg, dst=0, op=dist.ReduceOp.SUM)
        # 2. Rank 0: compute norm_var, rank, select top genes, build DataFrame and optionally save to CSV
        if trainer.global_rank == 0:
            # Retrieve adata.var annotation columns to enrich the output DataFrame.
            var_df: pd.DataFrame | None = None
            datamodule = getattr(trainer, "datamodule", None)
            if isinstance(datamodule, CellariumAnnDataDataModule):
                var_df = datamodule.dadc.schema.attr_values["var"]
            else:
                warnings.warn(
                    "HVGSeuratV3: trainer.datamodule is not a CellariumAnnDataDataModule; "
                    "adata.var annotations will not be added to the output DataFrame.",
                    UserWarning,
                    stacklevel=2,
                )
            # One DataFrame (and optionally one CSV pair) per requested n_top_genes.
            self.hvg_dfs = {}
            for n in self.n_top_genes_list:
                df = self._compute_hvg_df(n_top_genes=n, var_df=var_df)
                self.hvg_dfs[n] = df
                if self.output_path is not None:
                    path = self.output_path.replace(".csv", f"__top{n}.csv")
                    self._save(df=df, output_path=path)

    def _compute_hvg_df(self, n_top_genes: int, var_df: pd.DataFrame | None = None) -> pd.DataFrame:
        """Compute normalized variances, rank genes across batches, and flag the top HVGs.

        Args:
            n_top_genes: Number of genes to mark as highly variable.
            var_df: Optional adata.var annotations joined onto the result by gene label.

        Returns:
            DataFrame indexed by gene, sorted by rank, with Scanpy-compatible columns.
        """
        n_vars = self.n_vars
        norm_gene_var_bg = np.zeros((self.n_batch, n_vars), dtype=np.float64)
        for b in range(self.n_batch):
            N = self.x_size_b[b].item()
            if N < self.batch_n_cell_minimum:
                continue
            mean_g = (self.x_sums_bg[b] / N).cpu().numpy().astype(np.float64)
            reg_std_g = self.reg_std_bg[b].cpu().numpy().astype(np.float64)
            sum_g = self.counts_sum_bg[b].cpu().numpy().astype(np.float64)
            sq_sum_g = self.sq_counts_sum_bg[b].cpu().numpy().astype(np.float64)
            # Expansion of sum((clipped - mean)^2) / ((N - 1) * reg_std^2) using the
            # accumulated clipped sums; NaNs (0/0 for degenerate genes) become 0 below.
            denom_g = (N - 1) * reg_std_g**2
            with np.errstate(divide="ignore", invalid="ignore"):
                norm_gene_var_g = (N * mean_g**2 + sq_sum_g - 2 * mean_g * sum_g) / denom_g
            norm_gene_var_g[np.isnan(norm_gene_var_g)] = 0.0
            norm_gene_var_bg[b] = norm_gene_var_g

        # Only rank over batches with sufficient data (N >= batch_n_cell_minimum); invalid batches
        # have norm_gene_vars == 0 and would otherwise pollute num_batches_high_var
        # and median_ranked with arbitrary rankings of equal values.
        valid_b = np.array([self.x_size_b[b].item() >= self.batch_n_cell_minimum for b in range(self.n_batch)])
        norm_gene_vars_valid_vg = norm_gene_var_bg[valid_b]

        # Rank genes within each valid batch v
        hvg_rank_vg = np.argsort(np.argsort(-norm_gene_vars_valid_vg, axis=1), axis=1).astype(np.float32)
        num_batches_high_var_v = (hvg_rank_vg < n_top_genes).sum(axis=0).astype(int)
        # Ranks beyond n_top_genes are masked out of the median.
        hvg_rank_vg[hvg_rank_vg >= n_top_genes] = np.nan
        masked_hvg_rank_vg = np.ma.masked_invalid(hvg_rank_vg)
        median_hvg_rank_g = np.ma.median(masked_hvg_rank_vg, axis=0).filled(np.nan)
        variances_norm_g = norm_gene_vars_valid_vg.mean(axis=0)

        df = pd.DataFrame(
            index=pd.Index(self.var_names_g, name="gene"),
            data={
                "highly_variable_nbatches": num_batches_high_var_v,
                "highly_variable_rank": median_hvg_rank_g,
                "variances_norm": variances_norm_g,
            },
        )

        # Use integer-position sort so duplicate gene names don't cause
        # df.loc to mark more than n_top_genes rows as highly variable.
        rank_vals_g = df["highly_variable_rank"].fillna(np.inf).values
        nbatches_vals_g = df["highly_variable_nbatches"].values
        # np.lexsort: LAST key = primary sort key.
        if self.flavor == "seurat_v3":
            # Primary: rank ascending; tiebreaker: nbatches descending.
            sort_positions_g = np.lexsort((-nbatches_vals_g, rank_vals_g))
        else:  # seurat_v3_paper
            # Primary: nbatches descending; tiebreaker: rank ascending.
            sort_positions_g = np.lexsort((rank_vals_g, -nbatches_vals_g))
        eligible_g = df["highly_variable_nbatches"].values >= self.n_batch_minimum
        eligible_in_sort_order_g = sort_positions_g[eligible_g[sort_positions_g]]
        hvg_flags = np.zeros(len(df), dtype=bool)
        hvg_flags[eligible_in_sort_order_g[:n_top_genes]] = True
        df["highly_variable"] = hvg_flags
        if self.n_batch == 1:
            df = df.drop(columns=["highly_variable_nbatches"])

        # Append any extra columns from adata.var (e.g. gene_id, gene_name).
        # df.index values are var_names_g; var_df.index is adata.var_names.
        # pandas .join() aligns on label values, so index *name* differences
        # are harmless, and genes absent from var_df receive NaN.
        if var_df is not None:
            extra_cols = var_df.columns.difference(df.columns)
            if len(extra_cols) > 0:
                df = df.join(var_df[extra_cols], how="left")

        df = df.iloc[sort_positions_g]  # reorder rows by rank
        return df

    def _save(self, df: pd.DataFrame, output_path: str) -> None:
        """Write the full result CSV and a companion CSV restricted to the HVG rows."""
        df.to_csv(output_path)
        hvg_df = df[df["highly_variable"]]
        hvg_df.to_csv(output_path.replace(".csv", "__hvg_only.csv"), index=True, header=True)