mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-17 23:18:36 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop * pass t5 trace and meta_prop * [FX] refactor experimental tracer and adapt it with hf models * pass all mainstream model zoo * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * skip tests * fix CI * using packaging version * polish
This commit is contained in:
@@ -3,7 +3,8 @@ from numpy import isin
|
||||
from torch.fx import GraphModule
|
||||
from torch.utils._pytree import tree_flatten
|
||||
|
||||
from colossalai.fx import symbolic_trace
|
||||
# from colossalai.fx import symbolic_trace
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
||||
|
||||
def trace_model_and_compare_output(model, data_gen):
|
||||
|
@@ -1,4 +1,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
from packaging import version
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
@@ -6,6 +9,7 @@ BATCH_SIZE = 2
|
||||
SEQ_LENGTH = 16
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_albert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_albert')
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
from packaging import version
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_bert():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_bert')
|
||||
|
||||
|
@@ -1,16 +1,24 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
from packaging import version
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
# TODO: remove this skip once we handle the latest gpt model
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_gpt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_gpt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
|
||||
# TODO: support the following models
|
||||
# 1. GPT2DoubleHeadsModel
|
||||
# as they are not supported, let's skip them
|
||||
if model.__class__.__name__ in ['GPT2DoubleHeadsModel']:
|
||||
continue
|
||||
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
from packaging import version
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_opt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
||||
|
@@ -1,8 +1,12 @@
|
||||
import pytest
|
||||
import torch
|
||||
from hf_tracer_utils import trace_model_and_compare_output
|
||||
from packaging import version
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_t5():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_t5')
|
||||
|
||||
|
Reference in New Issue
Block a user