[fx] fixed compatiblity issue with torch 1.10 (#1331)

This commit is contained in:
Frank Lee 2022-07-18 11:41:27 +08:00 committed by GitHub
parent 069d6fdc84
commit 75abc75c15
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 32 additions and 28 deletions

View File

@ -2,6 +2,7 @@ import torch
from torch.fx.graph_module import GraphModule from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility from torch.fx._compatibility import compatibility
from packaging import version
import inspect import inspect
@ -233,6 +234,9 @@ def split_module(
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {} base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes: for node in m.graph.nodes:
if node.op == 'placeholder': if node.op == 'placeholder':
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type, type_expr=node.type,

View File

@ -3,6 +3,7 @@ from ..registry import meta_patched_function
@meta_patched_function.register(torch.matmul) @meta_patched_function.register(torch.matmul)
@meta_patched_function.register('matmul') # for built-in op @
def torch_matmul(input, other, *, out=None): def torch_matmul(input, other, *, out=None):
# copied from huggingface.utils.fx # copied from huggingface.utils.fx
d1 = input.dim() d1 = input.dim()

View File

@ -96,6 +96,9 @@ class ColoTracer(Tracer):
# fetch patched function # fetch patched function
if meta_patched_function.has(target): if meta_patched_function.has(target):
meta_target = meta_patched_function.get(target) meta_target = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
# use name for some builtin op like @ (matmul)
meta_target = meta_patched_function.get(target.__name__)
else: else:
meta_target = target meta_target = target

View File

@ -1,9 +1,5 @@
import torch import torch
import pytest import timm.models as tm
try:
import timm.models as tm
except:
pass
from timm_utils import split_model_and_compare_output from timm_utils import split_model_and_compare_output

View File

@ -1,16 +1,13 @@
import torch import torch
try: import torchvision
import torchvision.models as tm import torchvision.models as tm
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from torch.fx import GraphModule from torch.fx import GraphModule
from packaging import version
import random import random
import numpy as np import numpy as np
import inspect import inspect
import pytest
MANUAL_SEED = 0 MANUAL_SEED = 0
random.seed(MANUAL_SEED) random.seed(MANUAL_SEED)
@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5 tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
] ]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
tracer = ColoTracer() tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224) data = torch.rand(2, 3, 224, 224)

View File

@ -1,9 +1,5 @@
import torch import torch
import pytest import timm.models as tm
try:
import timm.models as tm
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule

View File

@ -1,9 +1,7 @@
import torch import torch
import pytest import torchvision
try: import torchvision.models as tm
import torchvision.models as tm from packaging import version
except:
pass
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule
@ -11,16 +9,22 @@ from torch.fx import GraphModule
def test_torchvision_models(): def test_torchvision_models():
MODEL_LIST = [ MODEL_LIST = [
tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2, tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0 tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
] ]
RANDOMIZED_MODELS = [tm.efficientnet_b0]
if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
RANDOMIZED_MODELS.append(tm.convnext_small)
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
tracer = ColoTracer() tracer = ColoTracer()
data = torch.rand(2, 3, 224, 224) data = torch.rand(2, 3, 224, 224)
for model_cls in MODEL_LIST: for model_cls in MODEL_LIST:
if model_cls in [tm.convnext_small, tm.efficientnet_b0]: if model_cls in RANDOMIZED_MODELS:
# remove the impact of randomicity # remove the impact of randomicity
model = model_cls(stochastic_depth_prob=0) model = model_cls(stochastic_depth_prob=0)
else: else: