[fx] fixed compatiblity issue with torch 1.10 (#1331)

This commit is contained in:
Frank Lee
2022-07-18 11:41:27 +08:00
committed by GitHub
parent 069d6fdc84
commit 75abc75c15
7 changed files with 32 additions and 28 deletions

View File

@@ -1,16 +1,13 @@
import torch
try:
import torchvision.models as tm
except:
pass
import torchvision
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule
from packaging import version
import random
import numpy as np
import inspect
import pytest
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)