This commit is contained in:
Pavel Belevich 2025-07-15 15:31:11 +08:00 committed by GitHub
commit 599797ca26
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1132,18 +1132,20 @@ def gather_state_dict_fast(
if rank == dst:
returned_state_dict = state_dict.copy()
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
ks, ops = [], []
for i, target_metadata in enumerate(all_meta_data):
if i == dst:
continue
ops = []
for k, shape, dtype in target_metadata:
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
returned_state_dict[k] = buffer
ks.append(k)
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
reqs = dist.batch_isend_irecv(ops)
for req, (k, *_) in zip(reqs, target_metadata):
req.wait()
returned_state_dict[k] = returned_state_dict[k].to(device)
reqs = dist.batch_isend_irecv(ops)
for req in reqs: # len(reqs) maybe be different from len(ops) because of coalescing
req.wait()
for k in ks:
returned_state_dict[k] = returned_state_dict[k].to(device)
return returned_state_dict
else:
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)