mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-02 09:38:05 +00:00

[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format

@@ -22,7 +22,7 @@ class ShardGradMemTracerHook(BaseOpHook):
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
-            assert hasattr(param, '_sharded_grad')
+            assert hasattr(param, "_sharded_grad")
             param._sharded_grad.setup()

     def post_bwd_exec(self, module: torch.nn.Module, input):

@@ -19,25 +19,25 @@ class ShardParamHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.gather()
             param.data = param.ca_attr.payload()

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.shard()
             param.data = param.ca_attr.payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.gather()
             param.data = param.ca_attr.payload()

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters():
-            assert hasattr(param, 'ca_attr')
+            assert hasattr(param, "ca_attr")
             param.ca_attr.shard()
             param.data = param.ca_attr.payload()

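Aside (not part of the commit): the four callbacks reformatted above implement a gather-before-compute / shard-after-compute lifecycle around each parameter's ca_attr. The standalone sketch below mirrors only that calling order using plain PyTorch module hooks; FakeShardAttr, the toy model, and the hook bodies are hypothetical stand-ins for illustration, not ColossalAI's implementation.

import torch
import torch.nn as nn


class FakeShardAttr:
    """Hypothetical stand-in for the diff's `ca_attr`: holds the full tensor and
    hands out either the full payload (gathered) or an empty one (sharded)."""

    def __init__(self, param: nn.Parameter):
        self._full = param.data.clone()
        self._gathered = False

    def gather(self):
        self._gathered = True

    def shard(self):
        self._gathered = False

    def payload(self) -> torch.Tensor:
        return self._full if self._gathered else torch.empty(0)


def pre_fwd(module, args):
    # Mirrors ShardParamHook.pre_fwd_exec: gather, then swap the payload into .data.
    for p in module.parameters(recurse=False):
        p.ca_attr.gather()
        p.data = p.ca_attr.payload()


def post_fwd(module, args, output):
    # Mirrors ShardParamHook.post_fwd_exec: shard, then swap the payload into .data.
    for p in module.parameters(recurse=False):
        p.ca_attr.shard()
        p.data = p.ca_attr.payload()


model = nn.Linear(4, 4)
for p in model.parameters():
    p.ca_attr = FakeShardAttr(p)
model.register_forward_pre_hook(pre_fwd)
model.register_forward_hook(post_fwd)

out = model(torch.randn(2, 4))    # parameters are "gathered" just in time for forward
print(model.weight.data.shape)    # torch.Size([0]) -- payload released again afterwards
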
@@ -15,8 +15,7 @@ class TrainingPhase(Enum):
     BACKWARD = 1


-class GradMemStats():
-
+class GradMemStats:
     def __init__(self) -> None:
         self.unreleased_grad_flag = {}
         self.unreleased_grad_volume = 0

@@ -26,8 +25,7 @@ class GradMemStats():
         self.unreleased_grad_volume = 0


-class GradMemTracerHook():
-
+class GradMemTracerHook:
     def __init__(self, grad_stats: GradMemStats):
         self.grad_hook_list = []
         self._grad_stats = grad_stats

@@ -50,7 +48,6 @@ class GradMemTracerHook():


 class ParamMemTracerHook(ColoParamOpHook):
-
     def __init__(self, memstats: MemStats, gradstats: GradMemStats) -> None:
         super().__init__()
         self._training_phase = TrainingPhase.FORWARD

@@ -79,10 +76,9 @@ class ParamMemTracerHook(ColoParamOpHook):
             if cur_dev == "cpu":
                 if p.grad is not None and p.grad.device.type == "cpu":
                     raise NotImplementedError("Only run in forward propagation")
-                p.data = torch.empty(p.data.shape,
-                                     device="cuda",
-                                     dtype=p.data.dtype,
-                                     requires_grad=p.data.requires_grad)
+                p.data = torch.empty(
+                    p.data.shape, device="cuda", dtype=p.data.dtype, requires_grad=p.data.requires_grad
+                )
             elif cur_dev == "cuda":
                 alloc_storage(p.data)

@@ -48,7 +48,6 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):


 class PreBackwardFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, module, pre_backward_function, outputs):
         ctx.module = module

@@ -64,7 +63,6 @@ class PreBackwardFunction(torch.autograd.Function):


 class PostBackwardFunction(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, module, pre_backward_function, output):
         ctx.module = module

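Aside (not part of the commit): PreBackwardFunction and PostBackwardFunction above are identity autograd Functions whose backward pass fires a module callback, which is how the backward-side op hooks get triggered. Below is a minimal self-contained sketch of that trick, with an illustrative callback rather than ColossalAI's, assuming only the standard torch.autograd.Function API.

import torch


class PreBackwardSketch(torch.autograd.Function):
    """Identity on the activation in forward; fires a callback in backward,
    just before gradients continue flowing into the producing module."""

    @staticmethod
    def forward(ctx, pre_backward_function, outputs):
        ctx.pre_backward_function = pre_backward_function
        return outputs.detach()

    @staticmethod
    def backward(ctx, *grads):
        ctx.pre_backward_function()      # e.g. re-gather sharded params here
        return (None,) + grads           # no gradient for the callback argument


x = torch.randn(4, requires_grad=True)
y = PreBackwardSketch.apply(lambda: print("pre-backward callback fired"), x * 2)
y.sum().backward()
print(x.grad)                            # tensor([2., 2., 2., 2.])
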
@@ -84,16 +82,15 @@ class PostBackwardFunction(torch.autograd.Function):
         return (None, None) + args


-def register_ophooks_recursively(module: torch.nn.Module,
-                                 ophook_list: List[BaseOpHook],
-                                 name: str = "",
-                                 filter_fn: Optional[Callable] = None):
+def register_ophooks_recursively(
+    module: torch.nn.Module, ophook_list: List[BaseOpHook], name: str = "", filter_fn: Optional[Callable] = None
+):
     r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD."""
     assert isinstance(module, torch.nn.Module)
     assert isinstance(ophook_list, (list, tuple))
-    assert len(ophook_list) > 0, 'expected at least 1 hook in the argument ophook_list but found 0'
+    assert len(ophook_list) > 0, "expected at least 1 hook in the argument ophook_list but found 0"
     for hook in ophook_list:
-        assert (isinstance(hook, BaseOpHook))
+        assert isinstance(hook, BaseOpHook)

     # Add hooks for submodules
     for child_name, child in module.named_children():

@@ -118,7 +115,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
             hook.post_fwd_exec(submodule, *args)

     def _pre_backward_module_hook(submodule, inputs, output):
-
         def _run_before_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)

@@ -127,7 +123,6 @@ def register_ophooks_recursively(module: torch.nn.Module,
         return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

     def _post_backward_module_hook(submodule, inputs):
-
         def _run_after_backward_function(submodule):
             for hook in ophook_list:
                 assert isinstance(submodule, torch.nn.Module)

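Aside (not part of the commit): the register_ophooks_recursively signature reformatted above walks named_children() and attaches forward/backward hooks per submodule. The sketch below mirrors that recursion pattern with native PyTorch hooks only; the leaf-only rule, the filter_fn semantics, and the printing callbacks are assumptions for illustration, not the ColossalAI implementation.

from typing import Callable, Optional

import torch
import torch.nn as nn


def register_print_hooks_recursively(
    module: nn.Module, name: str = "", filter_fn: Optional[Callable[[nn.Module], bool]] = None
):
    """Visit every submodule, optionally skip some via filter_fn, and attach
    forward pre/post hooks to the leaves -- the same recursion shape as above."""
    assert isinstance(module, nn.Module)

    # Recurse into children first, mirroring the named_children() loop in the diff.
    for child_name, child in module.named_children():
        register_print_hooks_recursively(child, name=f"{name}/{child_name}", filter_fn=filter_fn)

    # Only leaf modules get hooks, and filter_fn can veto a module.
    if len(list(module.children())) > 0:
        return
    if filter_fn is not None and not filter_fn(module):
        return

    module.register_forward_pre_hook(lambda mod, args: print(f"pre-forward  {name or '/'}"))
    module.register_forward_hook(lambda mod, args, out: print(f"post-forward {name or '/'}"))


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
register_print_hooks_recursively(model, filter_fn=lambda m: not isinstance(m, nn.ReLU))
model(torch.randn(1, 8))    # prints pre/post lines for /0 and /2, skips the ReLU
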