Mirror of https://github.com/hpcaitech/ColossalAI.git (last synced 2025-09-05 19:13:01 +00:00).
[fix] rm debug info; update llama policy; update wait handle
This commit is contained in:
@@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
accum_loss=accum_loss,
|
||||
outputs=outputs,
|
||||
)
|
||||
# print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
|
||||
|
||||
# Step3:
|
||||
# 3-1:detach output; detach output for send fwd;
|
||||
@@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
|
||||
for it in range(len(schedule)):
|
||||
scheduled_node = schedule[it]
|
||||
# print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
@@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
||||
model_chunk_id=scheduled_node.chunk,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
# wait here to ensure all communication is done
|
||||
for h in self.wait_handles:
|
||||
for hh in h:
|
||||
hh.wait()
|
||||
# return loss & output
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
|
Reference in New Issue
Block a user