From 54ee8d1254ed60da5f80e5a0d21e510fd495d725 Mon Sep 17 00:00:00 2001 From: Xu Kai Date: Wed, 9 Mar 2022 17:28:17 +0800 Subject: [PATCH] Fix/format colossalai/engine/paramhooks/(#350) --- colossalai/engine/paramhooks/__init__.py | 3 ++- colossalai/engine/paramhooks/_param_hookmgr.py | 10 ++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/colossalai/engine/paramhooks/__init__.py b/colossalai/engine/paramhooks/__init__.py index 6bafd7b91..7e4239937 100644 --- a/colossalai/engine/paramhooks/__init__.py +++ b/colossalai/engine/paramhooks/__init__.py @@ -1,2 +1,3 @@ from ._param_hookmgr import BaseParamHookMgr -__all__ = ["BaseParamHookMgr"] \ No newline at end of file + +__all__ = ["BaseParamHookMgr"] diff --git a/colossalai/engine/paramhooks/_param_hookmgr.py b/colossalai/engine/paramhooks/_param_hookmgr.py index 596a965eb..a1b995ccd 100644 --- a/colossalai/engine/paramhooks/_param_hookmgr.py +++ b/colossalai/engine/paramhooks/_param_hookmgr.py @@ -2,7 +2,9 @@ from typing import Callable, List import torch import functools + class BaseParamHookMgr(object): + def __init__(self, param_list: List[torch.nn.Parameter]) -> None: r""" register backward hook on every parameters of module @@ -10,17 +12,17 @@ class BaseParamHookMgr(object): self._param_list = param_list self._hook_list = [] - def register_backward_hooks(self, hook_call : Callable) -> None: + def register_backward_hooks(self, hook_call: Callable) -> None: r""" - The hook_call will be called every time a gradient with respect to the a param in self.param_list - is computed. + The hook_call will be called every time a gradient with respect to the a param in self.param_list + is computed. The hook should have the following signature: ``` hook(param, grad) -> Tensor or None ``` """ if not torch.is_grad_enabled(): - return # don't register grad hooks if grad isn't enabled + return # don't register grad hooks if grad isn't enabled for p in self._param_list: if p.requires_grad and not hasattr(p, '_base_param_hook'): handle = p.register_hook(functools.partial(hook_call, p))