mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[fx] fixed compatiblity issue with torch 1.10 (#1331)
This commit is contained in:
@@ -2,6 +2,7 @@ import torch
|
||||
from torch.fx.graph_module import GraphModule
|
||||
from typing import Callable, List, Dict, Any, Optional
|
||||
from torch.fx._compatibility import compatibility
|
||||
from packaging import version
|
||||
import inspect
|
||||
|
||||
|
||||
@@ -233,10 +234,13 @@ def split_module(
|
||||
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
|
||||
for node in m.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
|
||||
type_expr=node.type,
|
||||
default_value=default_value)
|
||||
if version.parse(torch.__version__) < version.parse('1.11.0'):
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
|
||||
else:
|
||||
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
|
||||
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
|
||||
type_expr=node.type,
|
||||
default_value=default_value)
|
||||
base_mod_env[node.name].meta = node.meta.copy()
|
||||
|
||||
# Do some things iterating over the partitions in topological order again:
|
||||
|
||||
@@ -3,6 +3,7 @@ from ..registry import meta_patched_function
|
||||
|
||||
|
||||
@meta_patched_function.register(torch.matmul)
|
||||
@meta_patched_function.register('matmul') # for built-in op @
|
||||
def torch_matmul(input, other, *, out=None):
|
||||
# copied from huggingface.utils.fx
|
||||
d1 = input.dim()
|
||||
|
||||
@@ -96,6 +96,9 @@ class ColoTracer(Tracer):
|
||||
# fetch patched function
|
||||
if meta_patched_function.has(target):
|
||||
meta_target = meta_patched_function.get(target)
|
||||
elif meta_patched_function.has(target.__name__):
|
||||
# use name for some builtin op like @ (matmul)
|
||||
meta_target = meta_patched_function.get(target.__name__)
|
||||
else:
|
||||
meta_target = target
|
||||
|
||||
|
||||
Reference in New Issue
Block a user