Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-07-23 11:44:15 +00:00
Commit 599797ca26: Merge dd562d28a4 into edd65a84dd
@@ -1132,18 +1132,20 @@ def gather_state_dict_fast(
     if rank == dst:
         returned_state_dict = state_dict.copy()
         dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
+        ks, ops = [], []
         for i, target_metadata in enumerate(all_meta_data):
             if i == dst:
                 continue
-            ops = []
             for k, shape, dtype in target_metadata:
                 buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
                 returned_state_dict[k] = buffer
+                ks.append(k)
                 ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
-            reqs = dist.batch_isend_irecv(ops)
-            for req, (k, *_) in zip(reqs, target_metadata):
-                req.wait()
-                returned_state_dict[k] = returned_state_dict[k].to(device)
+        reqs = dist.batch_isend_irecv(ops)
+        for req in reqs:  # len(reqs) may be different from len(ops) because of coalescing
+            req.wait()
+        for k in ks:
+            returned_state_dict[k] = returned_state_dict[k].to(device)
         return returned_state_dict
     else:
         dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
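For context, below is a minimal, self-contained sketch of the batched point-to-point pattern the new code on the destination rank relies on: build one buffer and one irecv op per incoming tensor, issue a single batch_isend_irecv, wait on the returned requests, and only then touch the buffers. This is not the ColossalAI implementation; the tensor names, shapes, backend choice (gloo) and the two-process launcher are assumptions made purely for illustration.

# Illustrative sketch only; names and shapes are hypothetical, not from ColossalAI.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    dst = 0
    if rank == dst:
        # Receiver: one buffer and one irecv op per expected tensor, all batched.
        ks, ops, received = [], [], {}
        for src in range(world_size):
            if src == dst:
                continue
            for k, shape in [("weight", (2, 2)), ("bias", (2,))]:  # hypothetical metadata
                buf = torch.empty(shape)
                received[f"{k}_{src}"] = buf
                ks.append(f"{k}_{src}")
                ops.append(dist.P2POp(dist.irecv, buf, src))
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:  # len(reqs) may differ from len(ops) if the backend coalesces ops
            req.wait()
        print({k: tuple(v.shape) for k, v in received.items()})
    else:
        # Sender: matching isend ops in the same order, also batched.
        ops = [
            dist.P2POp(dist.isend, torch.ones(2, 2) * rank, dst),
            dist.P2POp(dist.isend, torch.ones(2) * rank, dst),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)

The shape of the change follows from the coalescing caveat in the diff's own comment: because the requests returned by batch_isend_irecv cannot be zipped one-to-one with the op list (or with the per-rank metadata), the receiving rank records the gathered keys in ks, waits on every request first, and only afterwards moves the received buffers to the target device.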