[fx] fixed compatiblity issue with torch 1.10 (#1331)

This commit is contained in:
Frank Lee 2022-07-18 11:41:27 +08:00 committed by GitHub
parent 069d6fdc84
commit 75abc75c15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 28 deletions

View File

@ -2,6 +2,7 @@ import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect
@ -233,10 +234,13 @@ def split_module(
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again:

View File

@ -3,6 +3,7 @@ from ..registry import meta_patched_function
@meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul') # for built-in op @
def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx
d1 = input.dim()

View File

@ -96,6 +96,9 @@ class ColoTracer(Tracer):
# fetch patched function
if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else:
meta_target = target

View File

@ -1,9 +1,5 @@
import torch
import pytest
try:
import timm.models as tm
except:
pass
import timm.models as tm
from timm_utils import split_model_and_compare_output

View File

@ -1,16 +1,13 @@
import torch
try:
import torchvision.models as tm
except:
pass
import torchvision
import torchvision.models as tm
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule
from packaging import version
import random
import numpy as np
import inspect
import pytest
MANUAL_SEED = 0
random.seed(MANUAL_SEED)
@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)

View File

@ -1,9 +1,5 @@
import torch
import pytest
try:
import timm.models as tm
except:
pass
import timm.models as tm
from colossalai.fx import ColoTracer
from torch.fx import GraphModule

View File

@ -1,9 +1,7 @@
import torch
import pytest
try:
import torchvision.models as tm
except:
pass
import torchvision
import torchvision.models as tm
from packaging import version
from colossalai.fx import ColoTracer
from torch.fx import GraphModule
@ -11,16 +9,22 @@ from torch.fx import GraphModule
def test_torchvision_models():
MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
]
RANDOMIZED_MODELS = [tm.efficientnet_b0]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
RANDOMIZED_MODELS.append(tm.convnext_small)
torch.backends.cudnn.deterministic = True
tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST:
if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
if model_cls in RANDOMIZED_MODELS:
# remove the impact of randomicity
model = model_cls(stochastic_depth_prob=0)
else: