mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-10 05:20:33 +00:00
[shardformer] vit/llama/t5 ignore the sequence parallelism flag and some fix. (#4498)
* [shardformer] chatglm support sequence parallel
  [shardformer] chatglm support sequence parallel
  [shardformer] chatglm support sequence parallel
  [shardformer] chatglm support sequence parallel
  [shardformer] chatglm support sequence parallel
  [shardformer] chatglm support sequence parallel
* fix
  fix
  fix
  fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* [shardformer] jit fused fix
* activate checks
This commit is contained in:
@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
# check last hidden state & loss
|
||||
if stage_manager is None or stage_manager.is_last_stage():
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
@@ -77,7 +77,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
|
||||
# check weights and gradients
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
@@ -89,7 +89,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
org_optimizer.step()
|
||||
sharded_optimizer.step()
|
||||
if test_config['precision'] == 'fp32':
|
||||
atol, rtol = 1e-3, 1e-3
|
||||
atol, rtol = 2e-4, 2e-4
|
||||
else:
|
||||
atol, rtol = 5e-3, 5e-3
|
||||
if stage_manager is None or stage_manager.is_first_stage():
|
||||
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
|
||||
|
||||
# TODO(jianghai) fix fp16
|
||||
#TODO fix WhisperForConditionalGeneration enable jit fused operator
|
||||
@parameterize('test_config', [{
|
||||
'tp_size': 2,
|
||||
'pp_size': 2,
|
||||
|
Reference in New Issue
Block a user