[fx] fixed compatiblity issue with torch 1.10 (#1331)

This commit is contained in:
Frank Lee
2022-07-18 11:41:27 +08:00
committed by GitHub
parent 069d6fdc84
commit 75abc75c15
7 changed files with 32 additions and 28 deletions

View File

@@ -2,6 +2,7 @@ import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Any, Optional
from torch.fx._compatibility import compatibility
from packaging import version
import inspect
@@ -233,10 +234,13 @@ def split_module(
base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
for node in m.graph.nodes:
if node.op == 'placeholder':
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
if version.parse(torch.__version__) < version.parse('1.11.0'):
base_mod_env[node.name] = base_mod_graph.placeholder(node.name, type_expr=node.type)
else:
default_value = node.args[0] if len(node.args) > 0 else inspect.Signature.empty
base_mod_env[node.name] = base_mod_graph.placeholder(node.name,
type_expr=node.type,
default_value=default_value)
base_mod_env[node.name].meta = node.meta.copy()
# Do some things iterating over the partitions in topological order again: