[fix] Fix P2P communication error in the ZBV (zero-bubble V) pipeline scheduler

This commit is contained in:
duanjunwen 2024-11-14 09:40:38 +00:00
parent b6d5e61809
commit 1bc4dba3a3
2 changed files with 4 additions and 9 deletions

View File

@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
num_model_chunks: int, num_model_chunks: int,
num_microbatch: Optional[int] = None, num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None, microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = False, enable_metadata_cache: bool = True,
overlap_p2p: bool = True, overlap_p2p: bool = False,
): ):
super().__init__(stage_manager) super().__init__(stage_manager)
# overlap_p2p is not supported yet # overlap_p2p is not supported yet
# batch info # batch info
self.num_microbatch = num_microbatch self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
for h in self.wait_handles:
for hh in h:
hh.wait()
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:

View File

@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
@parameterize( @parameterize(
"config", "config",
[ [
# Pass
(1, 2, 1, 1, 2), (1, 2, 1, 1, 2),
(1, 1, 2, 2, 1), (1, 1, 2, 2, 1),
(1, 2, 1, 2, 1), (1, 2, 1, 2, 1),
(1, 2, 2, 1, 1), (1, 2, 2, 1, 1),
# # TODO: adapt mixtral with no TP Linear (1, 1, 4, 1, 1),
(0, 1, 4, 1, 1),
], ],
) )
def run_with_booster_moehybridplugin(config: Tuple[int, ...]): def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
(1, 2, 2, 1), (1, 2, 2, 1),
(1, 2, 1, 2), (1, 2, 1, 2),
(1, 1, 2, 2), (1, 1, 2, 2),
# TODO: support overlap p2p in pp4
(1, 4, 1, 1), (1, 4, 1, 1),
], ],
) )