mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
[Analyzer] fix analyzer tests (#3197)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torchvision.models as tm
|
||||
from packaging import version
|
||||
|
||||
try:
|
||||
from colossalai._analyzer._subclasses import MetaTensor, MetaTensorMode
|
||||
except:
|
||||
pass
|
||||
from .zoo import tm_models, tmm_models
|
||||
from tests.test_analyzer.test_fx.zoo import tm_models, tmm_models
|
||||
|
||||
|
||||
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor):
|
||||
@@ -28,7 +29,7 @@ def run_and_compare(model):
|
||||
compare_all(x.grad, meta_x.grad)
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.__version__ < '1.12.0', reason='torch version < 12')
|
||||
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
|
||||
@pytest.mark.parametrize('m', tm_models + tmm_models)
|
||||
def test_meta_mode_shape(m):
|
||||
run_and_compare(m())
|
||||
|
Reference in New Issue
Block a user