# Copyright Contributors to the Cellarium project.
# SPDX-License-Identifier: BSD-3-Clause
"""
Distributed utilities
---------------------
This module contains helper functions for distributed training.
"""
import warnings
import torch
import torch.distributed as dist
from torch.utils.data import get_worker_info as _get_worker_info
[docs]
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``GatherLayer.apply(x)`` returns a tuple with one tensor per rank of the default
    process group, where entry ``i`` holds the tensor contributed by rank ``i``.
    During backward, the gradients w.r.t. every gathered tensor are summed across
    all ranks, and each rank receives the summed gradient for the tensor it
    contributed.

    Note:
        The receive buffers are allocated with the local tensor's shape and dtype,
        so every rank is expected to contribute a tensor of the same shape and dtype.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> tuple[torch.Tensor, ...]:  # type: ignore
        # ``empty_like`` preserves the input's strides, so a non-contiguous input
        # (e.g. a transposed view) would also yield non-contiguous receive buffers;
        # collective backends such as NCCL reject non-contiguous tensors.
        tensor = input.contiguous()
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(output, tensor)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads: torch.Tensor) -> torch.Tensor:
        # ``grads[i]`` is this rank's gradient w.r.t. the tensor gathered from rank ``i``.
        # Summing across ranks gives the total gradient for every rank's contribution;
        # each rank then keeps only the slice belonging to the tensor it contributed.
        all_grads = torch.stack(grads)
        dist.all_reduce(all_grads, op=dist.ReduceOp.SUM)
        return all_grads[dist.get_rank()]
[docs]
def get_rank_and_num_replicas() -> tuple[int, int]:
    """
    Return the rank of the current process and the number of processes in the
    default process group.

    If the distributed package is not available, or the default process group has
    not been initialized, ``rank=0`` and ``num_replicas=1`` are returned; in the
    latter case a :class:`UserWarning` is also emitted.

    Returns:
        Tuple of ``rank`` and ``num_replicas``.

    Raises:
        ValueError: If the reported rank is outside ``[0, num_replicas - 1]``.
    """
    if not dist.is_available():
        return 0, 1

    # Ask explicitly whether the default group exists instead of catching the
    # exception raised by ``get_world_size``/``get_rank``: that exception type changed
    # across PyTorch releases (RuntimeError -> ValueError in 2.2), and a broad catch
    # would also swallow genuine errors coming from an initialized process group.
    if not dist.is_initialized():
        warnings.warn(
            "Distributed package is available but the default process group has not been initialized. "
            "Falling back to ``rank=0`` and ``num_replicas=1``.",
            UserWarning,
        )
        return 0, 1

    num_replicas = dist.get_world_size()
    rank = dist.get_rank()
    if rank >= num_replicas or rank < 0:
        raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas-1}]")
    return rank, num_replicas
[docs]
def get_worker_info() -> tuple[int, int]:
    """
    Return the id of the current data loading worker and the total worker count.

    When called from the main process rather than from a data loading worker,
    this falls back to ``worker_id=0`` and ``num_workers=1``.

    Returns:
        Tuple of ``worker_id`` and ``num_workers``.
    """
    info = _get_worker_info()
    if info is not None:
        return info.id, info.num_workers
    # Main process: behave as if there were exactly one worker.
    return 0, 1