18 lines
501 B
Python
18 lines
501 B
Python
# python3.7
|
|
"""Contains the synchronizing operator."""
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
# Public names exported by this module (used by `from <module> import *`).
__all__ = ['all_gather']
|
|
|
|
|
|
def all_gather(tensor):
    """Gathers tensor from all devices and does averaging.

    When the distributed package is unavailable, or no default process group
    has been initialized, the input tensor is returned untouched.

    Args:
        tensor: The local `torch.Tensor` to synchronize. NOTE(review): the
            receive buffers are allocated with the local tensor's shape and
            dtype, so every process presumably has to pass a tensor of the
            same shape and dtype -- confirm against callers.

    Returns:
        The input `tensor` itself in the non-distributed case; otherwise a
        new tensor holding the element-wise mean, over all processes, of the
        gathered tensors.
    """
    # `torch.distributed` exposes no other APIs when it is not available, so
    # `is_available()` must be checked before touching `is_initialized()`.
    if not (dist.is_available() and dist.is_initialized()):
        return tensor

    world_size = dist.get_world_size()
    # Every receive buffer is overwritten by `all_gather`, hence the buffers
    # do not need to be filled with any initial value.
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor, async_op=False)
    # Stack along a new leading (per-process) axis, then average it away.
    return torch.mean(torch.stack(tensor_list, dim=0), dim=0)