[fix] fix p2p error in zbv

This commit is contained in:
duanjunwen
2024-11-14 09:40:38 +00:00
parent b6d5e61809
commit 1bc4dba3a3
2 changed files with 4 additions and 9 deletions

View File

@@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = False,
overlap_p2p: bool = True,
enable_metadata_cache: bool = True,
overlap_p2p: bool = False,
):
super().__init__(stage_manager)
# Not support overlap_p2p so far
# batch info
self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size
@@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
for h in self.wait_handles:
for hh in h:
hh.wait()
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
# return loss & output
if outputs is not None: