[fix] fix mixtral modeling & policy; update wait handles; doing benchmarking for llama hybrid;

This commit is contained in:
duanjunwen
2024-11-15 05:58:56 +00:00
parent 014afbdb59
commit 5c2ebbfd48
4 changed files with 12 additions and 6 deletions

View File

@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
overlap_p2p: bool = False,
overlap_p2p: bool = True,
):
super().__init__(stage_manager)
# Not support overlap_p2p so far
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
for it in range(len(schedule)):
scheduled_node = schedule[it]
# print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_f(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
outputs=outputs,
)
elif scheduled_node.type == "B":
for h in self.wait_handles:
for hh in h:
hh.wait()
self.schedule_b(
scheduled_node=scheduled_node,
model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk,
optimizer=optimizer,
)
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
for h in self.wait_handles:
for hh in h:
hh.wait()
# return loss & output
if outputs is not None:
outputs = merge_batch(outputs)