[Analyzer] fix analyzer tests (#3197)

This commit is contained in:
YuliangLiu0306
2023-03-22 13:38:11 +08:00
committed by GitHub
parent f57d34958b
commit 019a847432
7 changed files with 60 additions and 100 deletions

View File

@@ -1,8 +1,10 @@
import pytest
import timm.models as tmm
import torch
import torchvision.models as tm
from .zoo import tm_models, tmm_models
from packaging import version
from colossalai.testing.utils import parameterize
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
try:
from colossalai._analyzer._subclasses import MetaTensorMode
@@ -16,8 +18,8 @@ def _check_gm_validity(gm: torch.fx.GraphModule):
assert len(node.meta['info'].global_ctx), f'In {gm.__class__.__name__}, {node} has empty global context.'
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tm_models)
def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
@@ -30,8 +32,8 @@ def test_torchvision_profile(m, verbose=False, bias_addition_split=False):
_check_gm_validity(gm)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.parametrize('m', tmm_models)
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@parameterize('m', tmm_models)
def test_timm_profile(m, verbose=False, bias_addition_split=False):
with MetaTensorMode():
model = m()
@@ -45,5 +47,5 @@ def test_timm_profile(m, verbose=False, bias_addition_split=False):
if __name__ == "__main__":
test_torchvision_profile(tm.vit_b_16, verbose=True, bias_addition_split=False)
test_timm_profile(tmm.gmlp_b16_224, verbose=True, bias_addition_split=False)
test_torchvision_profile()
test_timm_profile()