[fix] rm debug info; update llama policy; update wait handle

duanjunwen committed 2024-11-15 09:47:05 +00:00
parent cf86c1b1c5
commit 0fb500c7d4
3 changed files with 22 additions and 40 deletions


@@ -691,7 +691,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
             accum_loss=accum_loss,
             outputs=outputs,
         )
-        # print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
         # Step3:
         # 3-1:detach output; detach output for send fwd;
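
The "Step3" comments above refer to detaching the forward output before it is sent to the next stage, so that the send buffer does not pin this stage's autograd graph; the attached original is kept locally for the backward pass. A minimal sketch of that pattern follows; detach_for_send and the use of tree_map are illustrative assumptions, not code from this commit.

    import torch
    from torch.utils._pytree import tree_map

    def detach_for_send(output_obj):
        # Detach every tensor in the (possibly nested) output object so the
        # cross-stage send does not keep this stage's autograd graph alive.
        # The original, attached output_obj stays local for backward.
        def _detach(x):
            return x.detach() if isinstance(x, torch.Tensor) else x
        return tree_map(_detach, output_obj)
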
@@ -896,7 +895,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage] # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
-            # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
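
communication_map is a dispatch table: each scheduled node type in AUTO_SCHEDULE_COMMUNICATION_TYPES maps to the bound method that performs that communication, so the scheduling loop stays a flat table lookup instead of an if/elif chain. A hypothetical sketch of the pattern, with made-up node types and method names:

    class SchedulerSketch:
        def __init__(self):
            # one bound method per communication node type; the real scheduler
            # builds an equivalent map and indexes it with scheduled_node.type
            self.communication_map = {
                "SEND_FORWARD": self.send_forward,
                "RECV_FORWARD": self.recv_forward,
            }

        def send_forward(self):
            print("send activation to next stage")

        def recv_forward(self):
            print("receive activation from previous stage")

        def run(self, schedule):
            for node_type in schedule:
                if node_type in self.communication_map:
                    self.communication_map[node_type]()

    SchedulerSketch().run(["RECV_FORWARD", "SEND_FORWARD"])
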
@@ -925,6 +923,10 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
+        # wait here to ensure all communication is done
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
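
The added wait loop is nested because each non-blocking batched P2P call returns a list of work handles, and the scheduler accumulates one such list per communication node, making self.wait_handles a list of lists; draining all of them before returning guarantees every in-flight send/recv has completed before the loss and outputs are used. A sketch of that shape, assuming torch.distributed batched P2P with an initialized process group (the surrounding setup is illustrative, not this commit's code):

    import torch
    import torch.distributed as dist

    wait_handles = []  # list of lists: one inner list per async communication

    def async_send(tensor, peer):
        # dist.batch_isend_irecv returns a list of work handles, which is why
        # the commit drains wait_handles with two nested loops; requires an
        # initialized process group to actually run
        op = dist.P2POp(dist.isend, tensor, peer)
        wait_handles.append(dist.batch_isend_irecv([op]))

    # ... the scheduling loop issues many such async ops ...

    for handle_list in wait_handles:   # one entry per communication node
        for handle in handle_list:     # one work handle per P2P op in the batch
            handle.wait()              # block until the transfer is complete
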