mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-04 21:29:41 +00:00
58 lines
2.1 KiB
Python
58 lines
2.1 KiB
Python
from typing import Any, Dict
|
|
|
|
import ray.util.collective as cc
|
|
import torch
|
|
import torch.distributed.distributed_c10d as c10d
|
|
from packaging.version import Version
|
|
|
|
|
|
def ray_broadcast_object(obj: Any, src: int = 0, device=None, group_name: str = "default") -> Any:
    """Broadcast an arbitrary Python object from ``src`` to every rank of a Ray collective group.

    The object is converted to a byte tensor with torch's private
    ``c10d._object_to_tensor`` helper and sent in two collective steps: first
    the payload size, then the payload itself, so receivers can allocate a
    buffer of the right length before the second broadcast.

    Args:
        obj: Object to send. Only the value on rank ``src`` is used; on other
            ranks it is ignored and replaced by the received object.
        src: Rank that owns the object.
        device: Device on which the size/payload tensors are placed.
        group_name: Name of the ``ray.util.collective`` group to use.

    Returns:
        On rank ``src`` the original ``obj``; on every other rank the object
        rebuilt from the broadcast payload.
    """
    rank = cc.get_rank(group_name)
    # Parse once; the private c10d helpers changed signature across releases.
    torch_version = Version(torch.__version__)

    if rank == src:
        if torch_version >= Version("2.3.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device, group=None)
        elif torch_version >= Version("1.13.0"):
            obj_tensor, size_tensor = c10d._object_to_tensor(obj, device=device)
        else:
            obj_tensor, size_tensor = c10d._object_to_tensor(obj)
        # Older helpers take no device argument, so move explicitly in all cases.
        obj_tensor = obj_tensor.to(device)
        size_tensor = size_tensor.to(device)
    else:
        # Placeholder that the first broadcast fills with the payload length.
        size_tensor = torch.empty(1, dtype=torch.int64, device=device)

    cc.broadcast(size_tensor, src, group_name)

    if rank != src:
        # Receive buffer sized from the length announced by the source rank.
        obj_tensor = torch.empty(size_tensor.item(), dtype=torch.uint8, device=device)

    cc.broadcast(obj_tensor, src, group_name)

    if rank != src:
        if torch_version >= Version("2.3.0"):
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item(), group=None)
        else:
            # Deserialize the received payload tensor (not the caller's
            # placeholder ``obj``, which carries no data on receiving ranks).
            obj = c10d._tensor_to_object(obj_tensor, size_tensor.item())
    return obj
|
|
|
|
|
|
def ray_broadcast_tensor_dict(
    tensor_dict: Dict[str, torch.Tensor], src: int = 0, device=None, group_name: str = "default"
) -> Dict[str, torch.Tensor]:
    """Broadcast a dictionary of tensors from ``src`` to all ranks of a Ray collective group.

    The source rank first shares a ``(key, shape, dtype)`` list through
    ``ray_broadcast_object``; every tensor is then broadcast one at a time in
    that order, with receivers allocating an empty tensor of the announced
    shape and dtype on ``device`` to receive into.

    Args:
        tensor_dict: Tensors to send. Only read on rank ``src``.
        src: Rank that owns the tensors.
        device: Device for the receive buffers and for the metadata exchange.
        group_name: Name of the ``ray.util.collective`` group to use.

    Returns:
        On rank ``src`` the same ``tensor_dict`` object that was passed in; on
        other ranks a new dict mapping each key to the received tensor.
    """
    is_source = cc.get_rank(group_name) == src

    # Only the source knows the layout; everyone else contributes a placeholder.
    local_layout = (
        [(name, tensor.shape, tensor.dtype) for name, tensor in tensor_dict.items()] if is_source else None
    )
    layout = ray_broadcast_object(local_layout, src, device, group_name)

    if is_source:
        # Send each tensor in the agreed order and hand back the input dict.
        for name, _shape, _dtype in layout:
            cc.broadcast(tensor_dict[name], src, group_name)
        return tensor_dict

    received = {}
    for name, shape, dtype in layout:
        buffer = torch.empty(shape, dtype=dtype, device=device)
        cc.broadcast(buffer, src, group_name)
        received[name] = buffer
    return received
|