From 75abc75c15f5c212c285808b718ee5cd60372068 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Mon, 18 Jul 2022 11:41:27 +0800
Subject: [PATCH] [fx] fixed compatiblity issue with torch 1.10 (#1331)

---
 colossalai/fx/passes/split_module.py               | 12 ++++++++----
 .../meta_patch/patched_function/arithmetic.py      |  1 +
 colossalai/fx/tracer/tracer.py                     |  3 +++
 .../test_pipeline/test_timm_model/test_timm.py     |  6 +-----
 .../test_torchvision/test_torchvision.py           | 14 +++++++-------
 .../test_timm_model/test_timm_model.py             |  6 +-----
 .../test_torchvision_model.py                      | 18 +++++++++++-------
 7 files changed, 32 insertions(+), 28 deletions(-)

diff --git a/colossalai/fx/passes/split_module.py b/colossalai/fx/passes/split_module.py
index 7b5cbb3cd..8671855f4 100644
--- a/colossalai/fx/passes/split_module.py
+++ b/colossalai/fx/passes/split_module.py
@@ -2,6 +2,7 @@ import torch
 from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
+from packaging import version
 import inspect
 
 
@@ -233,10 +234,13 @@ def split_module(
     base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
     for node in m.graph.nodes:
         if node.op == 'placeholder':
-            default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-            base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
-                                                                 type_expr=node.type,
-                                                                 default_value=default_value)
+            if version.parse(torch.__version__) < version.parse('1.11.0'):
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+            else:
+                default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                                                                     type_expr=node.type,
+                                                                     default_value=default_value)
             base_mod_env[node.name].meta = node.meta.copy()
 
     # Do some things iterating over the partitions in topological order again:
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index 3077262db..f15621477 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
 
 
 @meta_patched_function.register(torch.matmul)
+@meta_patched_function.register('matmul')    # for built-in op @
 def torch_matmul(input, other, *, out=None):
     # copied from huggingface.utils.fx
     d1 = input.dim()
diff --git a/colossalai/fx/tracer/tracer.py b/colossalai/fx/tracer/tracer.py
index 33e14d57c..5b7a1eced 100644
--- a/colossalai/fx/tracer/tracer.py
+++ b/colossalai/fx/tracer/tracer.py
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
             # fetch patched function
             if meta_patched_function.has(target):
                 meta_target = meta_patched_function.get(target)
+            elif meta_patched_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                meta_target = meta_patched_function.get(target.__name__)
             else:
                 meta_target = target
 
diff --git a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
index da3843a27..c9ca452c4 100644
--- a/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
+++ b/tests/test_fx/test_pipeline/test_timm_model/test_timm.py
@@ -1,9 +1,5 @@
 import torch
-import pytest
-try:
-    import timm.models as tm
-except:
-    pass
+import timm.models as tm
 from timm_utils import split_model_and_compare_output
 
 
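A short aside on the two tracer changes above (this note and sketch are not part of the patch): torch.fx records the built-in `@` operator as `operator.matmul`, which is a different object from `torch.matmul`, so a registry keyed only on the torch function object never finds the patch. Registering the same patch under the string 'matmul' and falling back to `target.__name__` in the tracer closes that gap. A minimal, self-contained illustration, where `registry` merely stands in for `meta_patched_function` (same has/get interface as used above):

import operator

import torch

# torch.fx traces `a @ b` on proxies as a call_function node whose target is
# operator.matmul, not torch.matmul, so an object-keyed lookup misses it ...
assert operator.matmul is not torch.matmul
assert operator.matmul.__name__ == 'matmul'

# ... but the name-based fallback added to ColoTracer above still resolves it.
def resolve_meta_target(registry, target):
    """Hypothetical helper mirroring the new elif branch in ColoTracer."""
    if registry.has(target):
        return registry.get(target)
    if hasattr(target, '__name__') and registry.has(target.__name__):
        # use name for some builtin op like @ (matmul)
        return registry.get(target.__name__)
    return target
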
diff --git a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
index c03121063..b308d99c2 100644
--- a/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
+++ b/tests/test_fx/test_pipeline/test_torchvision/test_torchvision.py
@@ -1,16 +1,13 @@
 import torch
-try:
-    import torchvision.models as tm
-except:
-    pass
+import torchvision
+import torchvision.models as tm
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from torch.fx import GraphModule
-
+from packaging import version
 import random
 import numpy as np
 import inspect
-import pytest
 
 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
+        tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
     ]
 
+    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+        MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
+
     tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
index 5e2c40cac..a228e6c2e 100644
--- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
+++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py
@@ -1,9 +1,5 @@
 import torch
-import pytest
-try:
-    import timm.models as tm
-except:
-    pass
+import timm.models as tm
 from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 
diff --git a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
index 7360bd885..046a0dabe 100644
--- a/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
+++ b/tests/test_fx/test_tracer/test_torchvision_model/test_torchvision_model.py
@@ -1,9 +1,7 @@
 import torch
-import pytest
-try:
-    import torchvision.models as tm
-except:
-    pass
+import torchvision
+import torchvision.models as tm
+from packaging import version
 from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
 
@@ -11,16 +9,22 @@ from torch.fx import GraphModule
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
+        tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
     ]
 
+    RANDOMIZED_MODELS = [tm.efficientnet_b0]
+
+    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+        MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
+        RANDOMIZED_MODELS.append(tm.convnext_small)
+
     torch.backends.cudnn.deterministic = True
 
     tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)
 
     for model_cls in MODEL_LIST:
-        if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
+        if model_cls in RANDOMIZED_MODELS:
             # remove the impact of randomicity
             model = model_cls(stochastic_depth_prob=0)
         else:
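For reference, the same version-gate pattern appears twice in this patch: torch.fx.Graph.placeholder only accepts the default_value keyword from torch 1.11 onwards (hence the fallback in split_module.py), and torchvision only ships vit_b_16 and convnext_small from 0.12 onwards (hence the conditional MODEL_LIST.extend in the tests). A standalone sketch of that pattern; the constant and helper names here are illustrative and not part of the patch:

import torch
import torchvision
import torchvision.models as tm
from packaging import version

# torch < 1.11: Graph.placeholder() has no `default_value` keyword, so callers
# must fall back to the two-argument form used in split_module.py above.
LEGACY_TORCH = version.parse(torch.__version__) < version.parse('1.11.0')

# torchvision < 0.12: vit_b_16 / convnext_small do not exist yet, so the tests
# only add them when the installed torchvision is new enough.
NEW_TORCHVISION = version.parse(torchvision.__version__) >= version.parse('0.12.0')

def build_model_list():
    """Illustrative helper mirroring how the tests assemble MODEL_LIST."""
    model_list = [tm.vgg11, tm.resnet18, tm.efficientnet_b0]
    if NEW_TORCHVISION:
        model_list.extend([tm.vit_b_16, tm.convnext_small])
    return model_list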