Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 15:57:16 +00:00)
[fx] fixed compatibility issue with torch 1.10 (#1331)
commit 75abc75c15
parent 069d6fdc84
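The fix is a runtime version gate: rather than assuming the torch 1.11 APIs, the code compares the installed versions with packaging.version and falls back to the older call signatures and model lists. A minimal sketch of the guard (illustrative only; the constant name is not from the commit):

import torch
from packaging import version

# True on installs such as torch 1.10, where the newer fx placeholder
# signature used further down is not available.
TORCH_BEFORE_1_11 = version.parse(torch.__version__) < version.parse('1.11.0')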
@@ -2,6 +2,7 @@ import torch
 from torch.fx.graph_module import GraphModule
 from typing import Callable, List, Dict, Any, Optional
 from torch.fx._compatibility import compatibility
+from packaging import version
 import inspect


@@ -233,6 +234,9 @@ def split_module(
     base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
     for node in m.graph.nodes:
         if node.op == 'placeholder':
-            default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
-            base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
-                                                                 type_expr=node.type,
+            if version.parse(torch.__version__) < version.parse('1.11.0'):
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
+            else:
+                default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
+                base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
+                                                                     type_expr=node.type,
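On torch releases before 1.11 the placeholder is created without a default value, presumably because Graph.placeholder does not accept the default_value keyword there. A self-contained sketch of the same guarded call outside of split_module (the graph and type annotation are illustrative):

import torch
from packaging import version
from torch.fx import Graph

graph = Graph()
if version.parse(torch.__version__) < version.parse('1.11.0'):
    # older signature: only a name and an optional type annotation
    node = graph.placeholder('x', type_expr=torch.Tensor)
else:
    # newer signature additionally records a default value for the argument
    node = graph.placeholder('x', type_expr=torch.Tensor, default_value=None)
print(node.op, node.name)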
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function


 @meta_patched_function.register(torch.matmul)
+@meta_patched_function.register('matmul')    # for built-in op @
 def torch_matmul(input, other, *, out=None):
     # copied from huggingface.utils.fx
     d1 = input.dim()
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
             # fetch patched function
             if meta_patched_function.has(target):
                 meta_target = meta_patched_function.get(target)
+            elif meta_patched_function.has(target.__name__):
+                # use name for some builtin op like @ (matmul)
+                meta_target = meta_patched_function.get(target.__name__)
             else:
                 meta_target = target

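The two hunks above work together: torch_matmul is registered under the string 'matmul' as well as under torch.matmul, and the tracer falls back to looking the target up by __name__. That covers the built-in @ operator, which typically traces to operator.matmul rather than to torch.matmul. A toy sketch of the same fallback (PatchRegistry here is illustrative, not the real meta_patched_function registry):

import operator

class PatchRegistry:
    """Toy stand-in mirroring the register/has/get interface used above."""

    def __init__(self):
        self._store = {}

    def register(self, key):
        def wrapper(func):
            self._store[key] = func
            return func
        return wrapper

    def has(self, key):
        return key in self._store

    def get(self, key):
        return self._store[key]

patched = PatchRegistry()

@patched.register('matmul')
def fake_matmul(input, other):
    return None

target = operator.matmul    # what `a @ b` lowers to during tracing
if patched.has(target):
    meta_target = patched.get(target)
elif patched.has(target.__name__):
    # use the name for built-in ops such as @ (matmul)
    meta_target = patched.get(target.__name__)
else:
    meta_target = target
assert meta_target is fake_matmul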
@@ -1,9 +1,5 @@
 import torch
-import pytest
-try:
-    import timm.models as tm
-except:
-    pass
+import timm.models as tm
 from timm_utils import split_model_and_compare_output


@@ -1,16 +1,13 @@
 import torch
-try:
-    import torchvision
-    import torchvision.models as tm
-except:
-    pass
+import torchvision
+import torchvision.models as tm
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
 from torch.fx import GraphModule
+from packaging import version
 import random
 import numpy as np
 import inspect
-import pytest

 MANUAL_SEED = 0
 random.seed(MANUAL_SEED)
@@ -22,9 +19,12 @@ torch.backends.cudnn.deterministic = True
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.efficientnet_b0, tm.mnasnet0_5
+        tm.regnet_x_16gf, tm.efficientnet_b0, tm.mnasnet0_5
     ]

+    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+        MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
+
     tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)

@@ -1,9 +1,5 @@
 import torch
-import pytest
-try:
-    import timm.models as tm
-except:
-    pass
+import timm.models as tm
 from colossalai.fx import ColoTracer
 from torch.fx import GraphModule

@@ -1,9 +1,7 @@
 import torch
-import pytest
-try:
-    import torchvision.models as tm
-except:
-    pass
+import torchvision
+import torchvision.models as tm
+from packaging import version
 from colossalai.fx import ColoTracer
 from torch.fx import GraphModule

@@ -11,16 +9,22 @@ from torch.fx import GraphModule
 def test_torchvision_models():
     MODEL_LIST = [
         tm.vgg11, tm.resnet18, tm.densenet121, tm.mobilenet_v3_small, tm.resnext50_32x4d, tm.wide_resnet50_2,
-        tm.regnet_x_16gf, tm.vit_b_16, tm.convnext_small, tm.mnasnet0_5, tm.efficientnet_b0
+        tm.regnet_x_16gf, tm.mnasnet0_5, tm.efficientnet_b0
     ]

+    RANDOMIZED_MODELS = [tm.efficientnet_b0]
+
+    if version.parse(torchvision.__version__) >= version.parse('0.12.0'):
+        MODEL_LIST.extend([tm.vit_b_16, tm.convnext_small])
+        RANDOMIZED_MODELS.append(tm.convnext_small)
+
     torch.backends.cudnn.deterministic = True

     tracer = ColoTracer()
     data = torch.rand(2, 3, 224, 224)

     for model_cls in MODEL_LIST:
-        if model_cls in [tm.convnext_small, tm.efficientnet_b0]:
+        if model_cls in RANDOMIZED_MODELS:
             # remove the impact of randomicity
             model = model_cls(stochastic_depth_prob=0)
         else:
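vit_b_16 and convnext_small are skipped on older torchvision because they only ship from around 0.12.0, and convnext_small joins RANDOMIZED_MODELS so its stochastic depth can be switched off for a deterministic comparison. A small sketch of why that flag matters (model choice and input shape are illustrative):

import torch
import torchvision.models as tm

# With stochastic depth disabled (and the model in eval mode), two forward
# passes over the same input produce identical outputs, which is what an
# output-comparison test relies on.
model = tm.efficientnet_b0(stochastic_depth_prob=0)
model.eval()
data = torch.rand(2, 3, 224, 224)
with torch.no_grad():
    assert torch.equal(model(data), model(data))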