From bd18678478e5ecd18a9fa8a70eedea6f1fcdd036 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 5 Sep 2023 16:02:23 +0800
Subject: [PATCH] [test] fix gemini checkpoint and gpt test (#4620)

---
 .../test_plugins_huggingface_compatibility.py        | 2 +-
 tests/test_shardformer/test_model/test_shard_gpt2.py | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
index 3f3b0392a..bd041a5e2 100644
--- a/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
+++ b/tests/test_checkpoint_io/test_plugins_huggingface_compatibility.py
@@ -32,7 +32,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     elif plugin_type == 'zero':
         plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
     elif plugin_type == 'gemini':
-        plugin = GeminiPlugin(placement_policy='cuda', precision="fp16", initial_scale=32)
+        plugin = GeminiPlugin(precision="fp16", initial_scale=32)
     else:
         raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
 
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 24f5137ae..768063e53 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -102,7 +102,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     torch.cuda.empty_cache()
 
 
-@pytest.mark.skip(reason="This test will hang in CI")
 @parameterize('test_config', [{
     'tp_size': 2,
     'pp_size': 2,
@@ -220,7 +219,7 @@ def check_gpt2_3d(rank, world_size, port):
     run_gpt2_3d_test()
 
 
-
+@pytest.mark.skip(reason="This test will hang in CI")
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
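
For reference, below is a minimal sketch of how the two touched call sites read after this patch. It only mirrors what the diff shows: the GeminiPlugin import path matches the one used elsewhere in these test files, but build_plugin and test_gpt2_3d are hypothetical stand-ins for the real test helpers, which are not reproduced here.

import pytest
from colossalai.booster.plugin import GeminiPlugin  # import path assumed from the test suite

def build_plugin(plugin_type: str):  # hypothetical helper, not part of the patch
    if plugin_type == 'gemini':
        # After the patch, placement_policy='cuda' is no longer passed;
        # the plugin falls back to its default placement policy.
        return GeminiPlugin(precision="fp16", initial_scale=32)
    raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")

# After the patch, the skip marker sits on the outer distributed entry
# point rather than on the parameterized inner function, so pytest skips
# the whole test before it can hang in CI.
@pytest.mark.skip(reason="This test will hang in CI")
def test_gpt2_3d():  # hypothetical name for the outer pytest entry point
    ...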