# Licensed under the Apache License, Version 2.0.
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/utils/distributed.py
import os
from contextlib import contextmanager
from typing import Generator, Optional, Union
import torch


def init_distributed(use_cuda: bool) -> None:
    """Initialize the distributed backend for parallel training.

    This method sets up the distributed backend for parallel training based on the specified
    `use_cuda` flag. If `use_cuda` is `True`, it initializes the distributed mode using the
    CUDA/NCCL backend. Otherwise, it uses the Gloo backend.

    Args:
        use_cuda: Whether to initialize the distributed mode using the CUDA/NCCL backend.

    Raises:
        AssertionError: If the distributed mode is not initialized successfully.

    """

    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1

    if distributed:
        backend = "nccl" if use_cuda else "gloo"
        torch.distributed.init_process_group(backend=backend, init_method="env://")
        assert torch.distributed.is_initialized()
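

# Illustrative sketch (not part of the original module): `init_distributed` uses the "env://"
# init method, so it expects the standard `torchrun` environment variables (WORLD_SIZE, RANK,
# MASTER_ADDR, MASTER_PORT) to be set. The helper name `_example_setup` is hypothetical.
def _example_setup() -> None:
    # Pick the NCCL backend only when CUDA is actually available; otherwise fall back to Gloo.
    use_cuda = torch.cuda.is_available()
    init_distributed(use_cuda=use_cuda)

    # When WORLD_SIZE is unset or 1, `init_distributed` is a no-op and the helpers below
    # degrade gracefully to single-process behavior (rank 0, world size 1).
    print(f"rank={get_rank()} world_size={get_world_size()}")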


def barrier() -> None:
    """Synchronize all processes in the distributed backend.

    This method calls the `torch.distributed.barrier` function if the distributed mode is
    available and initialized. The `barrier` function synchronizes all processes in the
    distributed backend by blocking them until every process has reached this point.

    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()
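

# Illustrative sketch (not part of the original module): a common pattern is to let rank 0
# perform one-off filesystem work (e.g., downloading or extracting a dataset) while the other
# ranks wait at the barrier before reading the result. The helper name and path are hypothetical.
def _example_prepare_data(path: str = "/tmp/dataset") -> None:
    if get_rank() == 0:
        # Only rank 0 creates the directory (stand-in for a real download/extract step).
        os.makedirs(path, exist_ok=True)

    # All ranks block here until rank 0 has finished; a no-op when not running distributed.
    barrier()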


def get_rank() -> int:
    """Get the rank of the current process in the distributed backend.

    Returns:
        The rank of the current process in the distributed backend. If the distributed mode
        is not available or not initialized, the returned rank will be `0`.

    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank()

    return 0


def get_world_size() -> int:
    """Get the total number of processes in the distributed backend.

    Returns:
        The total number of processes in the distributed backend. If the distributed mode
        is not available or not initialized, the returned world size will be `1`.

    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size()

    return 1
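

# Illustrative sketch (not part of the original module): `get_rank` and `get_world_size` are
# typically combined to shard work across processes, e.g., giving each rank an interleaved
# slice of a list of samples. The helper name `_example_shard` is hypothetical.
def _example_shard(samples: list) -> list:
    rank = get_rank()
    world_size = get_world_size()

    # Rank r takes elements r, r + world_size, r + 2 * world_size, ...; with a single
    # process this returns the full list unchanged.
    return samples[rank::world_size]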


def all_reduce(tensor: Union[int, float, torch.Tensor], op: Optional[str] = "sum") -> Union[int, float]:
    """Reduce the input tensor/value into a scalar using the specified reduction operator.

    This method applies the specified reduction operator to the input tensor/value in a distributed
    manner. The result is a scalar value that is computed by aggregating the values from all
    processes in the distributed backend.

    Args:
        tensor: Input tensor/value to be reduced.
        op: Type of reduction operator. The supported operators are "sum", "mean",
            "min", "max", and "product".

    Returns:
        The scalar value obtained by applying the reduction operator to the input
        tensor/value. If the distributed mode is not available or not initialized,
        the input tensor/value is returned as is.

    Raises:
        RuntimeError: If the specified reduction operator or distributed backend is not supported.

    """

    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if op == "sum" or op == "mean":
            torch_op = torch.distributed.ReduceOp.SUM
        elif op == "min":
            torch_op = torch.distributed.ReduceOp.MIN
        elif op == "max":
            torch_op = torch.distributed.ReduceOp.MAX
        elif op == "product":
            torch_op = torch.distributed.ReduceOp.PRODUCT
        else:
            raise RuntimeError(f"Operator: {op} is not supported yet.")

        backend = torch.distributed.get_backend()
        if backend == torch.distributed.Backend.NCCL:
            device = torch.device("cuda")
        elif backend == torch.distributed.Backend.GLOO:
            device = torch.device("cpu")
        else:
            raise RuntimeError(f"Distributed backend: {backend} is not supported yet.")

        tensor = torch.tensor(tensor, device=device)
        torch.distributed.all_reduce(tensor, torch_op)
        if op == "mean":
            tensor /= get_world_size()

        return tensor.item()

    return tensor
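

# Illustrative sketch (not part of the original module): averaging a per-rank metric, such as
# a validation loss, across all processes. The helper name `_example_average_metric` is
# hypothetical.
def _example_average_metric(local_value: float) -> float:
    # With op="mean", every rank receives the same averaged scalar; when not running
    # distributed, the input value is returned unchanged.
    return float(all_reduce(local_value, op="mean"))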


@contextmanager
def sync_workers() -> Generator[int, None, None]:
    """Context manager for synchronizing the processes in the distributed backend.

    This context manager yields the rank of the current process in the distributed backend and
    synchronizes all processes on exit.

    Yields:
        The rank of the current process in the distributed backend.

    Example:
        >>> with sync_workers() as rank:
        ...     # Execute some code that should be synchronized across all processes.
        ...     pass

    """

    rank = get_rank()
    yield rank
    barrier()
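

# Illustrative sketch (not part of the original module): `sync_workers` is convenient for
# rank-0-only side effects such as checkpointing, since the barrier on exit keeps the other
# ranks from racing ahead before the file exists. The helper name and path are hypothetical.
def _example_save_checkpoint(state: dict, path: str = "/tmp/checkpoint.pt") -> None:
    with sync_workers() as rank:
        if rank == 0:
            # Only rank 0 writes the checkpoint; all ranks synchronize when the block exits.
            torch.save(state, path)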