mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 02:26:51 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -13,6 +13,7 @@ from torch.fx.graph import PythonCode
|
||||
|
||||
try:
|
||||
from torch.fx.graph import _PyTreeCodeGen
|
||||
|
||||
SUPPORT_PT_CODEGEN = True
|
||||
except ImportError:
|
||||
SUPPORT_PT_CODEGEN = False
|
||||
@@ -24,7 +25,6 @@ from torch.nn.modules.module import _addindent
|
||||
# This is a copy of torch.fx.graph_module._WrappedCall.
|
||||
# It should be removed when we stop supporting torch < 1.12.0.
|
||||
class _WrappedCall:
|
||||
|
||||
def __init__(self, cls, cls_call):
|
||||
self.cls = cls
|
||||
self.cls_call = cls_call
|
||||
@@ -50,12 +50,14 @@ class _WrappedCall:
|
||||
|
||||
# constituent substrings of the error message
|
||||
tb_repr = traceback.format_exc()
|
||||
custom_msg = ("Call using an FX-traced Module, "
|
||||
f"line {err_lineno} of the traced Module's "
|
||||
"generated forward function:")
|
||||
before_err = "".join(all_src_lines[err_lineno - 2:err_lineno])
|
||||
custom_msg = (
|
||||
"Call using an FX-traced Module, "
|
||||
f"line {err_lineno} of the traced Module's "
|
||||
"generated forward function:"
|
||||
)
|
||||
before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno])
|
||||
marker = "~" * err_line_len + "~~~ <--- HERE"
|
||||
err_and_after_err = "\n".join(all_src_lines[err_lineno:err_lineno + 2])
|
||||
err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2])
|
||||
|
||||
# joined message
|
||||
return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err])
|
||||
@@ -65,11 +67,14 @@ class _WrappedCall:
|
||||
if self.cls_call is not None:
|
||||
return self.cls_call(obj, *args, **kwargs)
|
||||
else:
|
||||
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
|
||||
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
|
||||
except Exception as e:
|
||||
assert e.__traceback__
|
||||
topmost_framesummary: traceback.FrameSummary = \
|
||||
traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] # type: ignore[arg-type]
|
||||
topmost_framesummary: traceback.FrameSummary = traceback.StackSummary.extract(
|
||||
traceback.walk_tb(e.__traceback__)
|
||||
)[
|
||||
-1
|
||||
] # type: ignore[arg-type]
|
||||
if "eval_with_key" in topmost_framesummary.filename:
|
||||
print(_WrappedCall._generate_error_message(topmost_framesummary), file=sys.stderr)
|
||||
raise e.with_traceback(None)
|
||||
@@ -99,10 +104,9 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
code.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: torch.fx.Graph,
|
||||
class_name: str = 'GraphModule'):
|
||||
def __init__(
|
||||
self, root: Union[torch.nn.Module, Dict[str, Any]], graph: torch.fx.Graph, class_name: str = "GraphModule"
|
||||
):
|
||||
super().__init__(root, graph, class_name)
|
||||
|
||||
def bind(self, ckpt_def, globals):
|
||||
@@ -134,7 +138,7 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
if SUPPORT_PT_CODEGEN and isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module='self')
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
self._code = python_code.src
|
||||
|
||||
# To split ckpt functions code and forward code
|
||||
@@ -157,8 +161,8 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
|
||||
cls_call = cls.__call__ if "__call__" in vars(cls) else None
|
||||
|
||||
if '_wrapped_call' not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
if "_wrapped_call" not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
return self._wrapped_call(self, *args, **kwargs)
|
||||
@@ -182,7 +186,7 @@ class ColoGraphModule(torch.fx.GraphModule):
|
||||
"""
|
||||
folder = Path(folder)
|
||||
Path(folder).mkdir(exist_ok=True)
|
||||
torch.save(self.state_dict(), folder / 'state_dict.pt')
|
||||
torch.save(self.state_dict(), folder / "state_dict.pt")
|
||||
tab = " " * 4
|
||||
|
||||
# we add import colossalai here
|
||||
@@ -208,10 +212,10 @@ class {module_name}(torch.nn.Module):
|
||||
for module_name, module in self.named_children():
|
||||
module_str = _gen_model_repr(module_name, module)
|
||||
if module_str is None:
|
||||
module_file = folder / f'{module_name}.pt'
|
||||
module_file = folder / f"{module_name}.pt"
|
||||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
|
||||
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
@@ -228,12 +232,14 @@ class {module_name}(torch.nn.Module):
|
||||
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
|
||||
model_str += f"{_addindent(self.code, 4)}\n"
|
||||
|
||||
module_file = folder / 'module.py'
|
||||
module_file = folder / "module.py"
|
||||
module_file.write_text(model_str)
|
||||
|
||||
init_file = folder / '__init__.py'
|
||||
init_file.write_text('from .module import *')
|
||||
init_file = folder / "__init__.py"
|
||||
init_file.write_text("from .module import *")
|
||||
|
||||
if len(blobified_modules) > 0:
|
||||
warnings.warn("Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}")
|
||||
warnings.warn(
|
||||
"Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}"
|
||||
)
|
||||
|
Reference in New Issue
Block a user