# Source code for cellarium.ml.utilities.distributed

# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause

"""
Distributed utilities
---------------------

This module contains helper functions for distributed training.
"""

import warnings

import torch
import torch.distributed as dist
from torch.utils.data import get_worker_info as _get_worker_info


class GatherLayer(torch.autograd.Function):
    """All-gather a tensor from every process in the default group, with autograd support.

    ``forward`` returns a tuple holding one tensor per rank, filled by
    ``dist.all_gather``. The receive buffers are created with
    ``torch.empty_like(input)``, so every rank must contribute a tensor of the
    same shape and dtype.

    ``backward`` stacks the gradients of all gathered outputs and sums them across
    processes with ``dist.all_reduce``. It then returns the slice at the local
    rank as the gradient of ``input``.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> tuple[torch.Tensor, ...]:  # type: ignore
        # One receive buffer per rank; dist.all_gather fills them in place.
        world_size = dist.get_world_size()
        gathered = [torch.empty_like(input) for _rank in range(world_size)]
        dist.all_gather(gathered, input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads: torch.Tensor) -> torch.Tensor:
        # grads[i] is the gradient w.r.t. the tensor gathered from rank i on this
        # process. Summing across processes collects every process's contribution,
        # and the local rank's slice is the gradient for this process's input.
        stacked = torch.stack(grads)
        dist.all_reduce(stacked, op=dist.ReduceOp.SUM)
        local_rank = dist.get_rank()
        return stacked[local_rank]


def get_rank_and_num_replicas() -> tuple[int, int]:
    """
    Return the rank of this process and the size of the default process group.

    The function falls back to ``rank=0`` and ``num_replicas=1`` in two cases:

    - :mod:`torch.distributed` is not available.
    - The package is available but querying the default process group fails
      with ``ValueError``/``RuntimeError``, i.e. the group has not been
      initialized. In this case a :class:`UserWarning` is also emitted.

    Returns:
        Tuple of ``rank`` and ``num_replicas``.

    Raises:
        ValueError: If the rank is outside the interval ``[0, num_replicas - 1]``.
    """
    # Single-process defaults, used whenever no process group can be queried.
    rank, num_replicas = 0, 1

    if dist.is_available():
        try:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        except (ValueError, RuntimeError):
            # Older PyTorch raised RuntimeError for an uninitialized default group;
            # PyTorch 2.2 switched to ValueError, so both are caught.
            warnings.warn(
                "Distributed package is available but the default process group has not been initialized. "
                "Falling back to ``rank=0`` and ``num_replicas=1``.",
                UserWarning,
            )
            rank, num_replicas = 0, 1

    if not 0 <= rank < num_replicas:
        raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas-1}]")
    return rank, num_replicas


def get_worker_info() -> tuple[int, int]:
    """
    Return ``worker_id`` and ``num_workers`` for the current data-loading context.

    When :func:`torch.utils.data.get_worker_info` reports no worker information
    (``None``, i.e. running in the main process), this helper treats the process
    as the only worker and returns ``worker_id=0`` and ``num_workers=1``.

    Returns:
        Tuple of ``worker_id`` and ``num_workers``.
    """
    info = _get_worker_info()
    if info is None:
        # Main process: behave as if there is exactly one worker.
        return 0, 1
    return info.id, info.num_workers