[feat] support mixtral policy with zbv tp_Linear & non_tp_Linear

This commit is contained in:
duanjunwen
2024-11-12 07:28:49 +00:00
parent 337debcf2a
commit 80b04d7855
4 changed files with 99 additions and 29 deletions

View File

@@ -45,7 +45,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
num_model_chunks: int,
num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = True,
enable_metadata_cache: bool = False,
overlap_p2p: bool = True,
):
super().__init__(stage_manager)
@@ -679,6 +679,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
accum_loss=accum_loss,
outputs=outputs,
)
# print(f"stage {self.stage_manager.stage}; chunk {model_chunk_id}; output_obj {output_obj}")
# Step3:
# 3-1:detach output; detach output for send fwd;