mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-05 11:02:05 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format
This commit is contained in:
@@ -1,32 +1,35 @@
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.modules.module import _addindent
|
||||
|
||||
try:
|
||||
from torch.fx.graph import Graph, PythonCode, _custom_builtins, _is_from_torch, _PyTreeCodeGen
|
||||
from torch.fx.graph_module import GraphModule, _EvalCacheLoader, _exec_with_source, _forward_from_src, _WrappedCall
|
||||
from torch.fx.graph import Graph, PythonCode, _PyTreeCodeGen
|
||||
from torch.fx.graph_module import GraphModule, _exec_with_source, _forward_from_src, _WrappedCall
|
||||
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
|
||||
COLOGM = True
|
||||
except:
|
||||
from torch.fx.graph import Graph
|
||||
from torch.fx.graph_module import GraphModule
|
||||
|
||||
COLOGM = False
|
||||
|
||||
if COLOGM:
|
||||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: Graph,
|
||||
class_name: str = 'GraphModule',
|
||||
ckpt_codegen: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
root: Union[torch.nn.Module, Dict[str, Any]],
|
||||
graph: Graph,
|
||||
class_name: str = "GraphModule",
|
||||
ckpt_codegen: bool = True,
|
||||
):
|
||||
if ckpt_codegen:
|
||||
graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
super().__init__(root, graph, class_name)
|
||||
@@ -60,7 +63,7 @@ if COLOGM:
|
||||
if isinstance(self._graph._codegen, _PyTreeCodeGen):
|
||||
self._in_spec = self._graph._codegen.pytree_info.in_spec
|
||||
self._out_spec = self._graph._codegen.pytree_info.out_spec
|
||||
python_code = self._graph.python_code(root_module='self')
|
||||
python_code = self._graph.python_code(root_module="self")
|
||||
self._code = python_code.src
|
||||
|
||||
# To split ckpt functions code and forward code
|
||||
@@ -83,8 +86,8 @@ if COLOGM:
|
||||
# bypass patching of torch.nn.Module.__call__ done while symbolic tracing.
|
||||
cls_call = cls.__call__ if "__call__" in vars(cls) else None
|
||||
|
||||
if '_wrapped_call' not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
if "_wrapped_call" not in vars(cls):
|
||||
cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined]
|
||||
|
||||
def call_wrapped(self, *args, **kwargs):
|
||||
return self._wrapped_call(self, *args, **kwargs)
|
||||
@@ -108,7 +111,7 @@ if COLOGM:
|
||||
"""
|
||||
folder = Path(folder)
|
||||
Path(folder).mkdir(exist_ok=True)
|
||||
torch.save(self.state_dict(), folder / 'state_dict.pt')
|
||||
torch.save(self.state_dict(), folder / "state_dict.pt")
|
||||
tab = " " * 4
|
||||
|
||||
# we add import colossalai here
|
||||
@@ -125,7 +128,13 @@ class {module_name}(torch.nn.Module):
|
||||
|
||||
def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]:
|
||||
safe_reprs = [
|
||||
nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
|
||||
nn.Linear,
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
nn.BatchNorm3d,
|
||||
]
|
||||
if type(module) in safe_reprs:
|
||||
return f"{module.__repr__()}"
|
||||
@@ -136,10 +145,10 @@ class {module_name}(torch.nn.Module):
|
||||
for module_name, module in self.named_children():
|
||||
module_str = _gen_model_repr(module_name, module)
|
||||
if module_str is None:
|
||||
module_file = folder / f'{module_name}.pt'
|
||||
module_file = folder / f"{module_name}.pt"
|
||||
torch.save(module, module_file)
|
||||
blobified_modules.append(module_name)
|
||||
module_repr = module.__repr__().replace('\r', ' ').replace('\n', ' ')
|
||||
module_repr = module.__repr__().replace("\r", " ").replace("\n", " ")
|
||||
module_str = f"torch.load(r'{module_file}') # {module_repr}"
|
||||
model_str += f"{tab*2}self.{module_name} = {module_str}\n"
|
||||
|
||||
@@ -156,19 +165,20 @@ class {module_name}(torch.nn.Module):
|
||||
model_str += f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n"
|
||||
model_str += f"{_addindent(self.code, 4)}\n"
|
||||
|
||||
module_file = folder / 'module.py'
|
||||
module_file = folder / "module.py"
|
||||
module_file.write_text(model_str)
|
||||
|
||||
init_file = folder / '__init__.py'
|
||||
init_file.write_text('from .module import *')
|
||||
init_file = folder / "__init__.py"
|
||||
init_file.write_text("from .module import *")
|
||||
|
||||
if len(blobified_modules) > 0:
|
||||
warnings.warn("Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}")
|
||||
warnings.warn(
|
||||
"Was not able to save the following children modules as reprs -"
|
||||
f"saved as pickled files instead: {blobified_modules}"
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
class ColoGraphModule(GraphModule):
|
||||
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = 'GraphModule'):
|
||||
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph, class_name: str = "GraphModule"):
|
||||
super().__init__(root, graph, class_name)
|
||||
|
Reference in New Issue
Block a user