# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
from collections.abc import Sequence
import torch
import torch.nn.functional as F
from torch import nn
from cellarium.ml.losses.nt_xent import NT_Xent
from cellarium.ml.models.model import CellariumModel, PredictMixin
class ContrastiveMLP(CellariumModel, PredictMixin):
    """
    Multilayer perceptron trained with contrastive learning (SimCLR-style,
    using the NT-Xent loss).

    Args:
        n_obs:
            Number of observations in each entry (network input size).
        hidden_size:
            Dimensionality of the fully-connected hidden layers. Must contain
            at least one element.
        embed_dim:
            Size of embedding (network output size).
        temperature:
            Parameter governing Normalized Temperature-scaled cross-entropy
            (NT-Xent) loss.
    """

    def __init__(
        self,
        n_obs: int,
        hidden_size: Sequence[int],
        embed_dim: int,
        temperature: float = 1.0,
    ) -> None:
        super().__init__()

        # Fail fast with a clear message instead of an opaque IndexError below.
        if not hidden_size:
            raise ValueError("hidden_size must contain at least one layer size.")

        # Build the MLP: (Linear -> BatchNorm1d -> ReLU) for the input layer and
        # each hidden-to-hidden transition, then a final linear projection to
        # the embedding dimension (no activation on the output).
        self.layers = nn.Sequential()
        self.layers.append(nn.Linear(n_obs, hidden_size[0]))
        self.layers.append(nn.BatchNorm1d(hidden_size[0]))
        self.layers.append(nn.ReLU())
        for size_i, size_j in zip(hidden_size[:-1], hidden_size[1:]):
            self.layers.append(nn.Linear(size_i, size_j))
            self.layers.append(nn.BatchNorm1d(size_j))
            self.layers.append(nn.ReLU())
        self.layers.append(nn.Linear(hidden_size[-1], embed_dim))

        self.Xent_loss = NT_Xent(temperature)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize all learnable parameters of the network."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                # Kaiming (He) init is matched to the ReLU activations used
                # between the linear layers.
                nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
                nn.init.constant_(layer.bias, 0.0)
            elif isinstance(layer, nn.BatchNorm1d):
                nn.init.constant_(layer.weight, 1.0)
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Args:
            x_ng:
                Gene counts matrix. NOTE(review): the batch is assumed to be
                two augmented views of the same cells stacked along dim 0
                (first half paired with second half); an odd batch size would
                yield unequal chunks — confirm against the data pipeline.

        Returns:
            A dictionary with the loss value under key ``"loss"``.
        """
        # Compute L2-normalized deep embeddings (normalization over dim 1).
        z = F.normalize(self.layers(x_ng))
        # Split the batch into the two augmented halves.
        z1, z2 = torch.chunk(z, 2)
        # SimCLR (NT-Xent) contrastive loss between the paired halves.
        loss = self.Xent_loss(z1, z2)
        return {"loss": loss}

    def predict(self, x_ng: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Sends (transformed) data through the model and returns outputs.

        Args:
            x_ng:
                Gene counts matrix.

        Returns:
            A dictionary with the L2-normalized embedding matrix under key
            ``"x_ng"``.
        """
        with torch.no_grad():
            z = F.normalize(self.layers(x_ng))
        return {"x_ng": z}