[shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
Insu Jang
2024-03-27 01:57:00 -04:00
committed by GitHub
parent e6707a6e8d
commit 00525f7772
18 changed files with 136 additions and 106 deletions

View File

@@ -38,9 +38,10 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
org_loss, dist_loss, atol=1e-5
), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
assert torch.allclose(
target_grad, dist_pred.grad
), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
@pytest.mark.dist