mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 12:30:42 +00:00
[pipeline] Add Pipeline Forward for GPT2Model Shardformer (#4224)
* * fix typehint & docstring in sharder.py * * update pipeline forward for GPT2Model * * add test for pipeline forward of GPT2Model * * add cache cleaning in gpt2 test * * change assert to raise command
This commit is contained in:
committed by
Hongxin Liu
parent
37d22f6878
commit
208ac8f2ba
@@ -65,6 +65,7 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
|
||||
assert torch.allclose(
|
||||
org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [True, False])
|
||||
@@ -77,6 +78,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user