Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-05 19:13:01 +00:00
[misc] update pre-commit and run all files (#4752)
* [misc] update pre-commit
* [misc] run pre-commit
* [misc] remove useless configuration files
* [misc] ignore cuda for clang-format
@@ -5,7 +5,6 @@ import torch
 
 
 class BaseParamHookMgr(object):
-
     def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
         r"""
         register backward hook on every parameters of module
@@ -23,9 +22,9 @@ class BaseParamHookMgr(object):
         ```
         """
         if not torch.is_grad_enabled():
-            return    # don't register grad hooks if grad isn't enabled
+            return  # don't register grad hooks if grad isn't enabled
         for p in self._param_list:
-            if p.requires_grad and not hasattr(p, '_base_param_hook'):
+            if p.requires_grad and not hasattr(p, "_base_param_hook"):
                 handle = p.register_hook(functools.partial(hook_call, p))
                 p._base_param_hook = handle
 
@@ -35,5 +34,5 @@ class BaseParamHookMgr(object):
         """
 
         for p in self._param_list:
-            if p.requires_grad and hasattr(p, '_base_param_hook'):
+            if p.requires_grad and hasattr(p, "_base_param_hook"):
                 p._base_param_hook.remove()
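The substantive hunks above are pure formatter output (black's double-quote and comment-spacing normalization); the hook mechanics are unchanged. For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the same per-parameter hook registration and removal. The model, the `log_grad` hook body, and the free-function names are illustrative assumptions, not ColossalAI's actual API; only the registration/removal logic mirrors the diffed code.

```python
import functools
from typing import Callable, List

import torch


def log_grad(param: torch.nn.Parameter, grad: torch.Tensor) -> None:
    # Fires once per backward pass for each registered parameter.
    # Matches the documented signature: hook(param, grad) -> Tensor or None.
    print(f"grad for param of shape {tuple(param.shape)}: norm={grad.norm().item():.4f}")


def register_backward_hooks(params: List[torch.nn.Parameter], hook_call: Callable) -> None:
    # Same guard and bookkeeping as the diffed code: skip when autograd is off,
    # bind each parameter into its hook via functools.partial, and stash the
    # handle on the parameter so the hook is registered at most once.
    if not torch.is_grad_enabled():
        return  # don't register grad hooks if grad isn't enabled
    for p in params:
        if p.requires_grad and not hasattr(p, "_base_param_hook"):
            handle = p.register_hook(functools.partial(hook_call, p))
            p._base_param_hook = handle


def remove_hooks(params: List[torch.nn.Parameter]) -> None:
    # Detach every hook registered above through its saved handle.
    for p in params:
        if p.requires_grad and hasattr(p, "_base_param_hook"):
            p._base_param_hook.remove()
            del p._base_param_hook  # allow re-registration later


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    params = list(model.parameters())
    register_backward_hooks(params, log_grad)
    model(torch.randn(3, 4)).sum().backward()  # hooks fire here
    remove_hooks(params)
```

Stashing the `RemovableHandle` directly on the parameter (rather than in a separate list) is what lets the `hasattr` checks make both registration and removal idempotent.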