[misc] update pre-commit and run all files (#4752)

* [misc] update pre-commit

* [misc] run pre-commit

* [misc] remove useless configuration files

* [misc] ignore cuda for clang-format
This commit is contained in:
Hongxin Liu
2023-09-19 14:20:26 +08:00
committed by GitHub
parent 3c6b831c26
commit 079bf3cb26
1268 changed files with 50037 additions and 38444 deletions

View File

@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
try:
from torch.fx.graph import _PyTreeCodeGen
SUPPORT_PT_CODEGEN = True
except ImportError:
SUPPORT_PT_CODEGEN = False
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
# This is a copy of torch.fx.graph_module._WrappedCall.
# It should be removed when we stop supporting torch < 1.12.0.
class _WrappedCall:
def __init__(self, cls, cls_call):
self.cls = cls
self.cls_call = cls_call
@@ -50,12 +50,14 @@ class _WrappedCall:
# constituent substrings of the error message
tb_repr = traceback.format_exc()
custom_msg = ("Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:")
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
custom_msg = (
"Call using an FX-traced Module, "
f"line {err_lineno} of the traced Module's "
"generated forward function:"
)
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
marker = "~" * err_line_len + "~~~ <--- HERE"
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
# joined message
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
@@ -65,11 +67,14 @@ class _WrappedCall:
if self.cls_call is not None:
return self.cls_call(obj, *args, **kwargs)
else:
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
except Exception as e:
assert e.__traceback__
topmost_framesummary: traceback.FrameSummary = \
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
traceback.walk_tb(e.__traceback__)
)[
-1
] # type: ignore[arg-type]
if "eval_with_key" in topmost_framesummary.filename:
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
raise e.with_traceback(None)
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
code.
"""
def __init__(self,
root: Union[torch.nn.Module, Dict[str, Any]],
graph: torch.fx.Graph,
class_name: str = 'GraphModule'):
def __init__(
self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
):
super().__init__(root, graph, class_name)
def bind(self, ckpt_def, globals):
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module='self')
python_code = self._graph.python_code(root_module="self")
self._code = python_code.src
# To split ckpt functions code and forward code
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
cls_call = cls.__call__ if "__call__" in vars(cls) else None
if '_wrapped_call' not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
if "_wrapped_call" not in vars(cls):
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
def call_wrapped(self, *args, **kwargs):
return self._wrapped_call(self, *args, **kwargs)
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
"""
folder = Path(folder)
Path(folder).mkdir(exist_ok=True)
torch.save(self.state_dict(), folder / 'state_dict.pt')
torch.save(self.state_dict(), folder / "state_dict.pt")
tab = " " * 4
# we add import colossalai here
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
for module_name, module in self.named_children():
module_str = _gen_model_repr(module_name, module)
if module_str is None:
module_file = folder / f'{module_name}.pt'
module_file = folder / f"{module_name}.pt"
torch.save(module, module_file)
blobified_modules.append(module_name)
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
module_str = f"torch.load(r'{module_file}') # {module_repr}"
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
model_str += f"{_addindent(self.code, 4)}\n"
module_file = folder / 'module.py'
module_file = folder / "module.py"
module_file.write_text(model_str)
init_file = folder / '__init__.py'
init_file.write_text('from .module import *')
init_file = folder / "__init__.py"
init_file.write_text("from .module import *")
if len(blobified_modules) > 0:
warnings.warn("Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}")
warnings.warn(
"Was not able to save the following children modules as reprs -"
f"saved as pickled files instead: {blobified_modules}"
)