[Analyzer] fix analyzer tests (#3197)

This commit is contained in:
YuliangLiu0306
2023-03-22 13:38:11 +08:00
committed by GitHub
parent f57d34958b
commit 019a847432
7 changed files with 60 additions and 100 deletions

View File

@@ -1,12 +1,13 @@
import pytest
import torch
import torch.distributed as dist
import torchvision.models as tm
from packaging import version
try:
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
except:
pass
from .zoo import tm_models, tmm_models
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
@@ -28,7 +29,7 @@ def run_and_compare(model):
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
@pytest.mark.parametrize('m', tm_models + tmm_models)
def test_meta_mode_shape(m):
run_and_compare(m())