[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
Insu Jang
2024-03-27 01:57:00 -04:00
committed by GitHub
parent e6707a6e8d
commit 00525f7772
18 changed files with 136 additions and 106 deletions

View File

@@ -83,7 +83,7 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
@parameterize("init_method", ["none", "lazy"])
def check_3d_plugin(init_method: str = "none", early_stop: bool = True):
"""check gemini plugin over model zoo
"""check hybrid plugin over model zoo
Args:
early_stop (bool, optional): Whether to stop when getting the first error. Defaults to True.
@@ -260,7 +260,7 @@ def run_grad_acc_test(test_args):
origin_model, origin_optimizer, dataloader=dataloader
)
for p1, p2 in zip(model.unwrap().parameters(), origin_model.unwrap().parameters()):
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
assert_close(p1.to(p2.dtype), p2, atol=1e-2, rtol=1e-2)
def run_dist(rank, world_size, port, early_stop: bool = True):
@@ -271,9 +271,9 @@ def run_dist(rank, world_size, port, early_stop: bool = True):
@rerun_if_address_is_in_use()
def test_gemini_plugin(early_stop: bool = True):
def test_3d_plugin(early_stop: bool = True):
spawn(run_dist, 4, early_stop=early_stop)
if __name__ == "__main__":
test_gemini_plugin(early_stop=False)
test_3d_plugin(early_stop=False)