[ci] fix shardformer tests. (#5255)

* fix ci

fix

* revert: revert p2p

* feat: add enable_metadata_cache option

* revert: enable t5 tests

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
flybird11111
2024-01-11 19:07:45 +08:00
committed by GitHub
parent 756c400ad2
commit e830ef917d
4 changed files with 20 additions and 3 deletions

View File

@@ -919,6 +919,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
"""
def __init__(
@@ -956,6 +957,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
enable_metadata_cache: bool = True,
) -> None:
super().__init__()
assert (
@@ -1002,10 +1004,14 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks=num_model_chunks,
num_microbatch=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
elif pp_style == "1f1b":
self.schedule = OneForwardOneBackwardSchedule(
self.stage_manager, num_microbatches=num_microbatches, microbatch_size=microbatch_size
stage_manager=self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size,
enable_metadata_cache=enable_metadata_cache,
)
else:
raise NotImplementedError()