Source code for cellarium.ml.transforms.duplicate

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause


import torch
from torch import nn


class Duplicate(nn.Module):
    """
    Row-duplicating transform intended for contrastive augmentations.

    When enabled, the input is tiled twice along its first dimension, so a
    2-D ``(n, g)`` input becomes ``(2n, g)`` with the second copy stacked
    after the first (row ``i`` reappears at row ``n + i``).
    """

    def __init__(self, enabled=True):
        """
        Args:
            enabled:
                Boolean switch. When truthy the input is duplicated; when
                falsy the transform is a pass-through. Disable it for model
                inference so the transform pipeline keeps the same structure
                as during training.
        """
        super().__init__()
        self.enabled = enabled

    def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x_ng:
                Gene counts.

        Returns:
            A dictionary with the single key ``"x_ng"`` holding the
            duplicated counts, or the input tensor itself when the
            transform is disabled.
        """
        # Pass-through: hand back the very same tensor object, no copy.
        if not self.enabled:
            return {"x_ng": x_ng}

        # Tile twice along dim 0 and once along dim 1.
        doubled = x_ng.repeat(2, 1)
        return {"x_ng": doubled}