mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-13 13:11:05 +00:00
[FX] refactor experimental tracer and adapt it with hf models (#3157)
* pass gpt trace and meta_prop * pass t5 trace and meta_prop * [FX] refactor experimental tracer and adapt it with hf models * pass all mainstream model zoo * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * fix CI * skip tests * fix CI * using packaging version * polish
This commit is contained in:
@@ -1,20 +1,18 @@
|
||||
import re
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from packaging import version
|
||||
from torchaudio_utils import trace_and_compare
|
||||
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
def test_torchaudio_models():
|
||||
torch.backends.cudnn.deterministic = True
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
|
||||
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
|
||||
# FIXME(ver217): temporarily skip these models
|
||||
if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
|
||||
continue
|
||||
model = model_fn()
|
||||
trace_and_compare(model,
|
||||
data_gen_fn,
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import torch
|
||||
|
||||
from colossalai.fx import symbolic_trace
|
||||
from colossalai._analyzer.fx import symbolic_trace
|
||||
|
||||
|
||||
def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
|
||||
|
Reference in New Issue
Block a user