[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)

* fix type hints & docstrings in sharder.py

* update pipeline forward for GPT2Model

* add test for pipeline forward of GPT2Model

* add cache cleaning in gpt2 test

* change assert statements to raise statements
This commit is contained in:
Baizhou Zhang
2023-07-13 15:34:06 +08:00
committed by Hongxin Liu
parent 37d22f6878
commit 208ac8f2ba
5 changed files with 357 additions and 9 deletions

View File

@@ -65,6 +65,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
assert torch.allclose(
org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@@ -77,6 +78,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()