mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 22:19:47 +00:00
[feature] fit non tensor broadcast (#6218)
This commit is contained in:
parent
de282dd694
commit
2bb71c6248
@ -37,13 +37,22 @@ def ray_broadcast_tensor_dict(
|
|||||||
rank = cc.get_rank(group_name)
|
rank = cc.get_rank(group_name)
|
||||||
if rank == src:
|
if rank == src:
|
||||||
metadata = []
|
metadata = []
|
||||||
|
non_tensor_dict = {}
|
||||||
for k, v in tensor_dict.items():
|
for k, v in tensor_dict.items():
|
||||||
metadata.append((k, v.shape, v.dtype))
|
if isinstance(v, torch.Tensor):
|
||||||
|
metadata.append((k, v.shape, v.dtype))
|
||||||
|
else:
|
||||||
|
non_tensor_dict[k] = v
|
||||||
else:
|
else:
|
||||||
metadata = None
|
metadata = None
|
||||||
metadata = ray_broadcast_object(metadata, src, device, group_name)
|
non_tensor_dict = None
|
||||||
|
|
||||||
|
data_to_broadcast = (metadata, non_tensor_dict)
|
||||||
|
data_to_broadcast = ray_broadcast_object(data_to_broadcast, src, device, group_name)
|
||||||
|
metadata, non_tensor_dict = data_to_broadcast
|
||||||
|
|
||||||
if rank != src:
|
if rank != src:
|
||||||
out_dict = {}
|
out_dict = non_tensor_dict
|
||||||
for k, shape, dtype in metadata:
|
for k, shape, dtype in metadata:
|
||||||
if rank == src:
|
if rank == src:
|
||||||
tensor = tensor_dict[k]
|
tensor = tensor_dict[k]
|
||||||
|
Loading…
Reference in New Issue
Block a user