# Source code for cellarium.ml.models.contrastive_mlp

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from cellarium.ml.losses.nt_xent import NT_Xent
from cellarium.ml.models.model import CellariumModel, PredictMixin


class ContrastiveMLP(CellariumModel, PredictMixin):
    """
    Multilayer perceptron trained with contrastive learning (SimCLR-style).

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` hidden layers
    followed by a final ``Linear`` projection to the embedding space.

    Args:
        n_obs:
            Number of observations in each entry (network input size).
        hidden_size:
            Dimensionality of the fully-connected hidden layers. Must contain
            at least one element.
        embed_dim:
            Size of embedding (network output size).
        temperature:
            Parameter governing Normalized Temperature-scaled cross-entropy
            (NT-Xent) loss.

    Raises:
        ValueError: If ``hidden_size`` is empty.
    """

    def __init__(
        self,
        n_obs: int,
        hidden_size: Sequence[int],
        embed_dim: int,
        temperature: float = 1.0,
    ) -> None:
        # Python 3 zero-argument form (original used super(ContrastiveMLP, self)).
        super().__init__()

        if not hidden_size:
            # Fail early with a clear message instead of the opaque IndexError
            # that hidden_size[0] would raise below.
            raise ValueError("`hidden_size` must contain at least one layer size.")

        # Input -> first hidden layer.
        self.layers = nn.Sequential()
        self.layers.append(nn.Linear(n_obs, hidden_size[0]))
        self.layers.append(nn.BatchNorm1d(hidden_size[0]))
        self.layers.append(nn.ReLU())
        # Hidden -> hidden layers (no-op when len(hidden_size) == 1).
        for size_i, size_j in zip(hidden_size[:-1], hidden_size[1:]):
            self.layers.append(nn.Linear(size_i, size_j))
            self.layers.append(nn.BatchNorm1d(size_j))
            self.layers.append(nn.ReLU())
        # Final projection to the embedding space (no activation/normalization;
        # embeddings are L2-normalized in forward/predict instead).
        self.layers.append(nn.Linear(hidden_size[-1], embed_dim))

        self.Xent_loss = NT_Xent(temperature)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize all learnable parameters in place."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                # Kaiming init matches the ReLU nonlinearity between layers.
                nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
                nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x_ng: Gene counts matrix.

        Returns:
            A dictionary with the loss value under the ``"loss"`` key.
        """
        # Compute deep embeddings, L2-normalized along the feature dimension
        # (dim=1 made explicit; it was the implicit F.normalize default).
        z = F.normalize(self.layers(x_ng), dim=1)
        # Split the batch into its two augmented halves.
        # NOTE(review): assumes the batch stacks both augmented views, so the
        # batch size should be even — torch.chunk would otherwise produce
        # unequal halves. Confirm against the data pipeline.
        z1, z2 = torch.chunk(z, 2)
        # SimCLR loss.
        loss = self.Xent_loss(z1, z2)
        return {"loss": loss}

    def predict(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Sends (transformed) data through the model and returns outputs.

        Args:
            x_ng: Gene counts matrix.

        Returns:
            A dictionary with the L2-normalized embedding matrix under the
            ``"x_ng"`` key.
        """
        # Inference only: no gradients are needed for the embedding pass.
        with torch.no_grad():
            z = F.normalize(self.layers(x_ng), dim=1)
        return {"x_ng": z}