From 6050f37776464c2d6eda52f46b7f81aebb7c0743 Mon Sep 17 00:00:00 2001 From: wukong1992 Date: Mon, 15 May 2023 19:35:21 +0800 Subject: [PATCH] [booster] removed models that don't support fsdp (#3744) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: 纪少敏 --- .../test_plugin/test_torch_fsdp_plugin.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 3f65e48ac..12562095c 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -46,7 +46,10 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): - if 'diffusers' in name: + if any(element in name for element in [ + 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', + 'torchvision_inception_v3' + ]): continue run_fn(model_fn, data_gen_fn, output_transform_fn) torch.cuda.empty_cache() @@ -58,12 +61,6 @@ def run_dist(rank, world_size, port): check_torch_fsdp_plugin() -# FIXME: this test is not working - - -@pytest.mark.skip( - "ValueError: expected to be in states [, ] but current state is TrainingState_.IDLE" -) @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher") @rerun_if_address_is_in_use() def test_torch_fsdp_plugin():