mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-25 12:35:02 +00:00
Merge dd562d28a4
into edd65a84dd
This commit is contained in:
commit
599797ca26
@ -1132,18 +1132,20 @@ def gather_state_dict_fast(
|
|||||||
if rank == dst:
|
if rank == dst:
|
||||||
returned_state_dict = state_dict.copy()
|
returned_state_dict = state_dict.copy()
|
||||||
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
|
dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group)
|
||||||
|
ks, ops = [], []
|
||||||
for i, target_metadata in enumerate(all_meta_data):
|
for i, target_metadata in enumerate(all_meta_data):
|
||||||
if i == dst:
|
if i == dst:
|
||||||
continue
|
continue
|
||||||
ops = []
|
|
||||||
for k, shape, dtype in target_metadata:
|
for k, shape, dtype in target_metadata:
|
||||||
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
|
buffer = torch.empty(shape, dtype=dtype, device=get_current_device())
|
||||||
returned_state_dict[k] = buffer
|
returned_state_dict[k] = buffer
|
||||||
|
ks.append(k)
|
||||||
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
|
ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group))
|
||||||
reqs = dist.batch_isend_irecv(ops)
|
reqs = dist.batch_isend_irecv(ops)
|
||||||
for req, (k, *_) in zip(reqs, target_metadata):
|
for req in reqs: # len(reqs) may differ from len(ops) because of coalescing
|
||||||
req.wait()
|
req.wait()
|
||||||
returned_state_dict[k] = returned_state_dict[k].to(device)
|
for k in ks:
|
||||||
|
returned_state_dict[k] = returned_state_dict[k].to(device)
|
||||||
return returned_state_dict
|
return returned_state_dict
|
||||||
else:
|
else:
|
||||||
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
|
dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)
|
||||||
|
Loading…
Reference in New Issue
Block a user