Source code for cellarium.ml.transforms.center_per_cell
# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
import torch
[docs]
class CenterPerCell(torch.nn.Module):
"""
Center each cell by subtracting its mean across genes.
"""
def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
cell_mean_n1 = x_ng.mean(dim=-1, keepdim=True)
x_ng = x_ng - cell_mean_n1
return {"x_ng": x_ng}