mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 19:13:01 +00:00
[ci] fix shardformer tests. (#5255)
* fix ci fix * revert: revert p2p * feat: add enable_metadata_cache option * revert: enable t5 tests --------- Co-authored-by: Wenhao Chen <cwher@outlook.com>
This commit is contained in:
@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
||||
)
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
@@ -200,7 +200,7 @@ def run_gpt2_test(test_config):
|
||||
)
|
||||
@clear_cache_before_run()
|
||||
def run_gpt2_3d_test(test_config):
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
|
||||
|
Reference in New Issue
Block a user