diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 4b36dbe00..ea97da1ba 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -1132,18 +1132,20 @@ def gather_state_dict_fast( if rank == dst: returned_state_dict = state_dict.copy() dist.gather_object(metadata, all_meta_data, dst=dist.get_global_rank(group, rank), group=group) + ks, ops = [], [] for i, target_metadata in enumerate(all_meta_data): if i == dst: continue - ops = [] for k, shape, dtype in target_metadata: buffer = torch.empty(shape, dtype=dtype, device=get_current_device()) returned_state_dict[k] = buffer + ks.append(k) ops.append(dist.P2POp(dist.irecv, buffer, dist.get_global_rank(group, i), group)) - reqs = dist.batch_isend_irecv(ops) - for req, (k, *_) in zip(reqs, target_metadata): - req.wait() - returned_state_dict[k] = returned_state_dict[k].to(device) + reqs = dist.batch_isend_irecv(ops) + for req in reqs: # len(reqs) may be different from len(ops) because of coalescing + req.wait() + for k in ks: + returned_state_dict[k] = returned_state_dict[k].to(device) return returned_state_dict else: dist.gather_object(metadata, dst=dist.get_global_rank(group, dst), group=group)