diff --git a/applications/ColossalChat/coati/distributed/comm.py b/applications/ColossalChat/coati/distributed/comm.py index 3824303f5..e6ecc9f2b 100644 --- a/applications/ColossalChat/coati/distributed/comm.py +++ b/applications/ColossalChat/coati/distributed/comm.py @@ -37,13 +37,22 @@ def ray_broadcast_tensor_dict( rank = cc.get_rank(group_name) if rank == src: metadata = [] + non_tensor_dict = {} for k, v in tensor_dict.items(): - metadata.append((k, v.shape, v.dtype)) + if isinstance(v, torch.Tensor): + metadata.append((k, v.shape, v.dtype)) + else: + non_tensor_dict[k] = v else: metadata = None - metadata = ray_broadcast_object(metadata, src, device, group_name) + non_tensor_dict = None + + data_to_broadcast = (metadata, non_tensor_dict) + data_to_broadcast = ray_broadcast_object(data_to_broadcast, src, device, group_name) + metadata, non_tensor_dict = data_to_broadcast + if rank != src: - out_dict = {} + out_dict = non_tensor_dict for k, shape, dtype in metadata: if rank == src: tensor = tensor_dict[k]