[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fixes (#4498)

* [shardformer] chatglm support sequence parallel

* fix

* [shardformer] jit fused fix

* activate checks
Author: flybird11111
Date: 2023-08-24 15:50:02 +08:00
Committed by: GitHub
Parent: e04436a82a
Commit: 3353e55c80
7 changed files with 46 additions and 21 deletions
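
The title says the ViT/LLaMA/T5 policies now ignore the sequence parallelism flag instead of acting on it. A minimal sketch of that pattern follows; it is not the actual ColossalAI code, and the class name, the preprocess hook, and the enable_sequence_parallelism attribute are assumptions made for illustration only:

```python
import warnings


class HypotheticalViTPolicy:
    """Illustrative policy that ignores an unsupported sequence-parallelism flag."""

    def __init__(self, shard_config):
        self.shard_config = shard_config

    def preprocess(self, model):
        # If the (assumed) flag is set for a model that does not support
        # sequence parallelism, warn and fall back instead of failing.
        if getattr(self.shard_config, "enable_sequence_parallelism", False):
            warnings.warn(
                "Sequence parallelism is not supported for this model; "
                "the flag will be ignored."
            )
            self.shard_config.enable_sequence_parallelism = False
        return model
```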


@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check last hidden state & loss
     if stage_manager is None or stage_manager.is_last_stage():
         if test_config['precision'] == 'fp32':
-            atol, rtol = 1e-3, 1e-3
+            atol, rtol = 2e-4, 2e-4
         else:
             atol, rtol = 5e-3, 5e-3
@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check weights and gradients
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_optimizer.step()
     sharded_optimizer.step()
     if test_config['precision'] == 'fp32':
-        atol, rtol = 1e-3, 1e-3
+        atol, rtol = 2e-4, 2e-4
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 # TODO(jianghai) fix fp16
+# TODO fix WhisperForConditionalGeneration enable jit fused operator
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
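
The hunks above only adjust the per-precision atol/rtol pairs. As a minimal sketch (assumed, not the repository's actual check_forward_backward helper), this is how such tolerances are typically fed into torch.allclose when comparing the original and sharded models:

```python
import torch


def assert_close(org_tensor: torch.Tensor,
                 sharded_tensor: torch.Tensor,
                 precision: str = 'fp32') -> None:
    # Per-precision tolerances mirroring the fp32 / fp16 split in the hunks above:
    # tight values for fp32, looser ones for half precision.
    if precision == 'fp32':
        atol, rtol = 2e-4, 2e-4
    else:
        atol, rtol = 5e-3, 5e-3
    assert torch.allclose(org_tensor, sharded_tensor, atol=atol, rtol=rtol), \
        f"outputs differ beyond atol={atol}, rtol={rtol}"
```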