ColossalAI/applications/ColossalChat/coati/distributed/comm.py
2025-02-21 15:24:23 +08:00

58 lines
2.1 KiB
Python

from typing import Any, Dict
import ray.util.collective as cc
import torch
import torch.distributed.distributed_c10d as c10d
from packaging.version import Version
def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
    """Broadcast a Python object from rank ``src`` to all ranks of a ray collective group.

    The object is turned into a byte tensor with torch's private
    ``c10d._object_to_tensor`` helper and sent in two collective steps: first
    a one-element int64 tensor carrying the payload size, then the payload
    itself. Every rank of ``group_name`` must call this function, because each
    ``cc.broadcast`` is a collective operation.

    Args:
        obj: Object to send. Only read on the source rank; other ranks may
            pass any placeholder (e.g. ``None``).
        src: Rank within the group that owns the object.
        device: Device on which the size and payload tensors are placed.
        group_name: Name of the ray collective group.

    Returns:
        ``obj`` itself on the source rank, and the object reconstructed from
        the received bytes on every other rank.
    """
    rank = cc.get_rank(group_name)
    is_src = rank == src
    # Parse once: the private c10d helpers changed signature across releases.
    torch_version = Version(torch.__version__)

    if is_src:
        if torch_version >= Version("2.3.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
        elif torch_version >= Version("1.13.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
        else:
            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
        # Make sure both tensors live on the requested device regardless of
        # which helper signature produced them.
        obj_tensor = obj_tensor.to(device)
        size_tensor = size_tensor.to(device)
    else:
        size_tensor = torch.empty(1, dtype=torch.int64, device=device)

    # Step 1: share the payload size so receivers can allocate a buffer.
    cc.broadcast(size_tensor, src, group_name)

    if is_src:
        # Step 2 (sender): ship the serialized payload and hand back the input.
        cc.broadcast(obj_tensor, src, group_name)
        return obj

    # Step 2 (receiver): allocate a buffer of the announced size and fill it.
    num_bytes = size_tensor.item()
    obj_tensor = torch.empty(num_bytes, dtype=torch.uint8, device=device)
    cc.broadcast(obj_tensor, src, group_name)

    if torch_version >= Version("2.3.0"):
        return c10d._tensor_to_object(obj_tensor, num_bytes, group=None)
    # Decode the received buffer, not the caller's placeholder ``obj``.
    return c10d._tensor_to_object(obj_tensor, num_bytes)
def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
    """Broadcast a dict of tensors from rank ``src`` to all ranks of a ray collective group.

    The source rank first shares a list of ``(key, shape, dtype)`` entries via
    ``ray_broadcast_object``; each tensor is then broadcast one by one in that
    order. Every rank of ``group_name`` must call this function.

    Args:
        tensor_dict: Tensors to send. Only read on the source rank.
        src: Rank within the group that owns the tensors.
        device: Device on which receiving ranks allocate their buffers (also
            forwarded to ``ray_broadcast_object`` for the metadata exchange).
        group_name: Name of the ray collective group.

    Returns:
        The input ``tensor_dict`` on the source rank; on other ranks, a new
        dict mapping each key to the tensor that was received.
    """
    is_sender = cc.get_rank(group_name) == src

    # Only the sender knows the layout; receivers learn it from the broadcast.
    layout = [(name, t.shape, t.dtype) for name, t in tensor_dict.items()] if is_sender else None
    layout = ray_broadcast_object(layout, src, device, group_name)

    if is_sender:
        for name, _shape, _dtype in layout:
            cc.broadcast(tensor_dict[name], src, group_name)
        return tensor_dict

    received = {}
    for name, shape, dtype in layout:
        # Allocate an uninitialized buffer that the broadcast fills in place.
        buffer = torch.empty(shape, dtype=dtype, device=device)
        cc.broadcast(buffer, src, group_name)
        received[name] = buffer
    return received