mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-20 04:32:47 +00:00
[fix] fix p2p error in zbv
This commit is contained in:
parent
b6d5e61809
commit
1bc4dba3a3
@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
num_model_chunks: int,
|
num_model_chunks: int,
|
||||||
num_microbatch: Optional[int] = None,
|
num_microbatch: Optional[int] = None,
|
||||||
microbatch_size: Optional[int] = None,
|
microbatch_size: Optional[int] = None,
|
||||||
enable_metadata_cache: bool = False,
|
enable_metadata_cache: bool = True,
|
||||||
overlap_p2p: bool = True,
|
overlap_p2p: bool = False,
|
||||||
):
|
):
|
||||||
super().__init__(stage_manager)
|
super().__init__(stage_manager)
|
||||||
|
# Not support overlap_p2p so far
|
||||||
# batch info
|
# batch info
|
||||||
self.num_microbatch = num_microbatch
|
self.num_microbatch = num_microbatch
|
||||||
self.microbatch_size = microbatch_size
|
self.microbatch_size = microbatch_size
|
||||||
@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||||||
model_chunk_id=scheduled_node.chunk,
|
model_chunk_id=scheduled_node.chunk,
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
)
|
)
|
||||||
for h in self.wait_handles:
|
|
||||||
for hh in h:
|
|
||||||
hh.wait()
|
|
||||||
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
|
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
|
||||||
# return loss & output
|
# return loss & output
|
||||||
if outputs is not None:
|
if outputs is not None:
|
||||||
|
@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
|
|||||||
@parameterize(
|
@parameterize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
# Pass
|
|
||||||
(1, 2, 1, 1, 2),
|
(1, 2, 1, 1, 2),
|
||||||
(1, 1, 2, 2, 1),
|
(1, 1, 2, 2, 1),
|
||||||
(1, 2, 1, 2, 1),
|
(1, 2, 1, 2, 1),
|
||||||
(1, 2, 2, 1, 1),
|
(1, 2, 2, 1, 1),
|
||||||
# # TODO: adapt mixtral with no TP Linear
|
(1, 1, 4, 1, 1),
|
||||||
(0, 1, 4, 1, 1),
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
||||||
@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
|
|||||||
(1, 2, 2, 1),
|
(1, 2, 2, 1),
|
||||||
(1, 2, 1, 2),
|
(1, 2, 1, 2),
|
||||||
(1, 1, 2, 2),
|
(1, 1, 2, 2),
|
||||||
# TODO: support overlap p2p in pp4
|
|
||||||
(1, 4, 1, 1),
|
(1, 4, 1, 1),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user