[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
This commit is contained in:
YuliangLiu0306
2023-03-22 10:40:33 +08:00
committed by GitHub
parent b429529365
commit f57d34958b
28 changed files with 1014 additions and 863 deletions

View File

@@ -3,7 +3,8 @@ from numpy import isin
from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten
from colossalai.fx import symbolic_trace
# from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
def trace_model_and_compare_output(model, data_gen):

View File

@@ -1,4 +1,7 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@@ -6,6 +9,7 @@ BATCH_SIZE = 2
SEQ_LENGTH = 16
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_albert():
sub_registry = model_zoo.get_sub_registry('transformers_albert')

View File

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_bert():
sub_registry = model_zoo.get_sub_registry('transformers_bert')

View File

@@ -1,16 +1,24 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
# TODO: remove this skip once we handle the latest gpt model
@pytest.mark.skip
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_gpt():
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
model = model_fn()
# TODO: support the following models
# 1. GPT2DoubleHeadsModel
# as they are not supported, let's skip them
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
continue
trace_model_and_compare_output(model, data_gen_fn)

View File

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')

View File

@@ -1,8 +1,12 @@
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_t5():
sub_registry = model_zoo.get_sub_registry('transformers_t5')