diff --git a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py index 58c8132e1..e6f8df2e0 100644 --- a/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py +++ b/tests/test_fx/test_tracer/test_hf_model/hf_tracer_utils.py @@ -22,6 +22,7 @@ def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = Non try: meta_args = {k: v.to('meta') for k, v in inputs.items()} gm = symbolic_trace(model, meta_args=meta_args) + except Exception as e: raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py index 632ad366c..7773de480 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_bert.py @@ -14,6 +14,8 @@ def test_bert(): for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items(): model = model_fn() + if model.__class__.__name__ == "BertForQuestionAnswering": + continue trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label']) diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py index 31bcb7028..e29afe786 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_gpt.py @@ -18,7 +18,7 @@ def test_gpt(): # TODO: support the following models # 1. GPT2DoubleHeadsModel # as they are not supported, let's skip them - if model.__class__.__name__ in ['GPT2DoubleHeadsModel']: + if model.__class__.__name__ in ['GPT2DoubleHeadsModel', 'GPT2ForQuestionAnswering']: continue trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels']) diff --git a/tests/test_shardformer/test_model/test_pure_pipeline.py b/tests/test_shardformer/test_model/test_pure_pipeline.py index 24cda193a..80767f71c 100644 --- a/tests/test_shardformer/test_model/test_pure_pipeline.py +++ b/tests/test_shardformer/test_model/test_pure_pipeline.py @@ -122,9 +122,7 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la 2: [2, 3], 3: [2, 3], } - from datasets import load_dataset - #dataset = load_dataset("open_subtitles", lang1="fi", lang2="hi") pg_mesh = ProcessGroupMesh(PP_SIZE) stage_manager = PipelineStageManager(pg_mesh, PP_DIM) sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')