Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-05 11:02:05 +00:00)
[fix] fix mixtral modeling & policy; update wait handles; benchmarking llama hybrid in progress;
@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        overlap_p2p: bool = False,
+        overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
         # Not support overlap_p2p so far
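
Note on the default change: overlap_p2p switches the scheduler from blocking point-to-point transfers to asynchronous ones whose completion is deferred via wait handles. A minimal sketch of that distinction, assuming torch.distributed batched P2P ops (the function below is illustrative, not ColossalAI's internal API):

    import torch
    import torch.distributed as dist

    def send_recv(tensor_to_send, recv_shape, peer, overlap_p2p=True):
        # batch_isend_irecv returns a list of work handles; with overlap
        # enabled we hand them back to the caller so compute can proceed
        # while the transfer is in flight.
        recv_buf = torch.empty(recv_shape, device=tensor_to_send.device)
        ops = [
            dist.P2POp(dist.isend, tensor_to_send, peer),
            dist.P2POp(dist.irecv, recv_buf, peer),
        ]
        handles = dist.batch_isend_irecv(ops)
        if not overlap_p2p:
            for h in handles:
                h.wait()  # blocking mode: complete before returning
            return recv_buf, []
        return recv_buf, handles  # caller must wait() before reading recv_buf
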
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
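
The nested loop added before schedule_f assumes self.wait_handles holds one list of work objects per batched P2P call, so every pending transfer finishes before forward compute touches the received buffers. The same drain pattern in isolation, as a hedged sketch:

    def drain_handles(wait_handles):
        # wait_handles: list of lists of torch.distributed Work objects,
        # one inner list per batched isend/irecv launch.
        for handles in wait_handles:
            for h in handles:
                h.wait()  # block until this transfer has completed

As committed, the list is not cleared between drains, so later F/B nodes re-wait handles that already completed; wait() on a finished work object returns immediately, so this is redundant but harmless.
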
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     outputs=outputs,
                 )
             elif scheduled_node.type == "B":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_b(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
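
After the schedule loop, a final drain flushes handles launched by trailing communication nodes, then the per-microbatch outputs are merged. A rough stand-in for the merge step, assuming tensor outputs concatenated along the batch dimension (ColossalAI's actual merge_batch also handles nested/dict outputs):

    import torch

    def merge_batch_sketch(outputs, dim=0):
        # Concatenate per-microbatch results back into one full batch.
        # Tensor-only sketch; the real merge_batch supports richer outputs.
        return torch.cat(outputs, dim=dim)
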