From 2bb71c6248f1c71eae502de34e6f73b2e17c4f29 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 24 Feb 2025 14:36:04 +0800 Subject: [PATCH] [feature] fit non tensor broadcast (#6218) --- .../ColossalChat/coati/distributed/comm.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py index 3824303f5..e6ecc9f2b 100644 --- a/applications/ColossalChat/coati/distributed/comm.py +++ b/applications/ColossalChat/coati/distributed/comm.py @@ -37,13 +37,22 @@ def ray_broadcast_tensor_dict( rank = cc.get_rank(group_name) if rank == src: metadata = [] + non_tensor_dict = {} for k, v in tensor_dict.items(): - metadata.append((k, v.shape, v.dtype)) + if isinstance(v, torch.Tensor): + metadata.append((k, v.shape, v.dtype)) + else: + non_tensor_dict[k] = v else: metadata = None - metadata = ray_broadcast_object(metadata, src, device, group_name) + non_tensor_dict = None + + data_to_broadcast = (metadata, non_tensor_dict) + data_to_broadcast = ray_broadcast_object(data_to_broadcast, src, device, group_name) + metadata, non_tensor_dict = data_to_broadcast + if rank != src: - out_dict = {} + out_dict = non_tensor_dict for k, shape, dtype in metadata: if rank == src: tensor = tensor_dict[k]