[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
YuliangLiu0306
2023-03-22 10:40:33 +08:00
committed by GitHub
parent b429529365
commit f57d34958b
28 changed files with 1014 additions and 863 deletions


@@ -1,4 +1,7 @@
+import linecache
 import os
+import sys
+import traceback
 import warnings
 from pathlib import Path
 from typing import Any, Dict, Optional, Union
@@ -6,11 +9,74 @@ from typing import Any, Dict, Optional, Union
 import torch
 import torch.fx
 import torch.nn as nn
-from torch.fx.graph import PythonCode, _PyTreeCodeGen
-from torch.fx.graph_module import _exec_with_source, _forward_from_src, _WrappedCall
+from torch.fx.graph import PythonCode
+
+try:
+    from torch.fx.graph import _PyTreeCodeGen
+    SUPPORT_PT_CODEGEN = True
+except ImportError:
+    SUPPORT_PT_CODEGEN = False
+
+from torch.fx.graph_module import _exec_with_source, _forward_from_src
 from torch.nn.modules.module import _addindent
 
 
+# This is a copy of torch.fx.graph_module._WrappedCall.
+# It should be removed when we stop supporting torch < 1.12.0.
+class _WrappedCall:
+
+    def __init__(self, cls, cls_call):
+        self.cls = cls
+        self.cls_call = cls_call
+
+    # Previously, if an error occurred when valid
+    # symbolically-traced code was run with an invalid input, the
+    # user would see the source of the error as coming from
+    # `File "<eval_with_key_N">`, where N is some number. We use
+    # this function to generate a more informative error message. We
+    # return the traceback itself, a message explaining that the
+    # error occurred in a traced Module's generated forward
+    # function, and five lines of context surrounding the faulty
+    # line
+    @staticmethod
+    def _generate_error_message(frame_summary: traceback.FrameSummary) -> str:
+        # auxiliary variables (for readability)
+        err_lineno = frame_summary.lineno
+        assert err_lineno is not None
+        line = frame_summary.line
+        assert line is not None
+        err_line_len = len(line)
+        all_src_lines = linecache.getlines(frame_summary.filename)
+
+        # constituent substrings of the error message
+        tb_repr = traceback.format_exc()
+        custom_msg = ("Call using an FX-traced Module, "
+                      f"line {err_lineno} of the traced Module's "
+                      "generated forward function:")
+        before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
+        marker = "~" * err_line_len + "~~~ <--- HERE"
+        err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
+
+        # joined message
+        return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
+
+    def __call__(self, obj, *args, **kwargs):
+        try:
+            if self.cls_call is not None:
+                return self.cls_call(obj, *args, **kwargs)
+            else:
+                return super(self.cls, obj).__call__(*args, **kwargs)    # type: ignore[misc]
+        except Exception as e:
+            assert e.__traceback__
+            topmost_framesummary: traceback.FrameSummary = \
+                traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1]    # type: ignore[arg-type]
+            if "eval_with_key" in topmost_framesummary.filename:
+                print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
+                raise e.with_traceback(None)
+            else:
+                raise e
+
+
 class ColoGraphModule(torch.fx.GraphModule):
     """
     ColoGraphGraphModule is an nn.Module generated from an fx.Graph.
@@ -65,7 +131,7 @@ class ColoGraphModule(torch.fx.GraphModule):
         called after editing the contained ``graph``, otherwise the generated
         code of this ``GraphModule`` will be out of date.
         """
-        if isinstance(self._graph._codegen, _PyTreeCodeGen):
+        if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
             self._in_spec = self._graph._codegen.pytree_info.in_spec
             self._out_spec = self._graph._codegen.pytree_info.out_spec
         python_code = self._graph.python_code(root_module='self')
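
The diff above feature-detects `_PyTreeCodeGen` with a try/except and sets the `SUPPORT_PT_CODEGEN` flag from the import result. The commit message also mentions "using packaging version"; a minimal sketch of that alternative gate, written by the editor for illustration and not part of this diff, could look like the following. The `1.12.0` cutoff is an assumption taken from the copied `_WrappedCall` comment.

    # Illustrative sketch only: gate optional torch.fx features on the
    # installed torch version instead of relying solely on ImportError.
    from packaging import version

    import torch

    # Assumption: _PyTreeCodeGen (and the upstream _WrappedCall) are available
    # from torch 1.12.0 onward, per the comment in the vendored copy above.
    SUPPORT_PT_CODEGEN = version.parse(torch.__version__) >= version.parse("1.12.0")

    if SUPPORT_PT_CODEGEN:
        from torch.fx.graph import _PyTreeCodeGen

The try/except form used in the diff has the advantage of following the actual symbol availability rather than a hard-coded version string, at the cost of silently masking other import problems.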
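For context, a hedged usage sketch of how a `GraphModule` subclass like `ColoGraphModule` is typically driven. The import path and the `(root, graph)` constructor are assumptions (the file path is not shown in this diff, and only `recompile` appears above); the `torch.fx` calls themselves are standard.

    # Sketch under assumptions: ColoGraphModule keeps torch.fx.GraphModule's
    # (root, graph) constructor and is importable from colossalai.fx.
    import torch
    import torch.nn as nn
    from colossalai.fx import ColoGraphModule  # assumed import path

    class TinyModel(nn.Module):
        def forward(self, x):
            return torch.relu(x) + 1

    traced = torch.fx.symbolic_trace(TinyModel())
    gm = ColoGraphModule(TinyModel(), traced.graph)

    # After editing gm.graph, recompile() regenerates the forward source;
    # the SUPPORT_PT_CODEGEN guard in the hunk above keeps this path
    # working on torch versions without _PyTreeCodeGen.
    gm.recompile()
    print(gm(torch.randn(4)))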