diff --git a/colossalai/engine/paramhooks/__init__.py b/colossalai/engine/paramhooks/__init__.py
new file mode 100644
index 000000000..6bafd7b91
--- /dev/null
+++ b/colossalai/engine/paramhooks/__init__.py
@@ -0,0 +1,2 @@
+from ._param_hookmgr import BaseParamHookMgr
+__all__ = ["BaseParamHookMgr"]
\ No newline at end of file
diff --git a/colossalai/engine/paramhooks/_param_hookmgr.py b/colossalai/engine/paramhooks/_param_hookmgr.py
new file mode 100644
index 000000000..596a965eb
--- /dev/null
+++ b/colossalai/engine/paramhooks/_param_hookmgr.py
@@ -0,0 +1,32 @@
+from typing import Callable, List
+import torch
+import functools
+
+class BaseParamHookMgr(object):
+    def __init__(self, param_list: List[torch.nn.Parameter]) -> None:
+        r"""
+        Manages backward hooks registered on every parameter in ``param_list``.
+        """
+        self._param_list = param_list
+        self._hook_list = []
+
+    def register_backward_hooks(self, hook_call: Callable) -> None:
+        r"""
+        The hook_call will be called every time a gradient with respect to a param
+        in self._param_list is computed.
+        The hook should have the following signature:
+        ```
+        hook(param, grad) -> Tensor or None
+        ```
+        """
+        if not torch.is_grad_enabled():
+            return  # don't register grad hooks if grad isn't enabled
+        for p in self._param_list:
+            if p.requires_grad and not hasattr(p, '_base_param_hook'):
+                handle = p.register_hook(functools.partial(hook_call, p))
+                p._base_param_hook = handle
+
+    def remove_hooks(self):
+        for p in self._param_list:
+            if p.requires_grad and hasattr(p, '_base_param_hook'):
+                p._base_param_hook.remove()
diff --git a/tests/test_engine/test_engine/test_param_hook.py b/tests/test_engine/test_engine/test_param_hook.py
new file mode 100644
index 000000000..54639157f
--- /dev/null
+++ b/tests/test_engine/test_engine/test_param_hook.py
@@ -0,0 +1,86 @@
+import pytest
+from colossalai.engine.paramhooks import BaseParamHookMgr
+from torch import nn
+import torch
+import torch.nn.functional as F
+import copy
+
+class SubNet(nn.Module):
+    def __init__(self, out_features) -> None:
+        super().__init__()
+        self.bias = nn.Parameter(torch.zeros(out_features))
+
+    def forward(self, x, weight):
+        return F.linear(x, weight, self.bias)
+
+
+class Net(nn.Module):
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__()
+        self.fc1 = nn.Linear(5, 5)
+        self.sub_fc = SubNet(5)
+        self.fc2 = nn.Linear(5, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.sub_fc(x, self.fc1.weight)
+        x = self.fc1(x)
+        x = self.fc2(x)
+        return x
+
+def net_data():
+    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
+
+def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
+    if loose:
+        return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3)
+    return torch.allclose(tensor_a, tensor_b)
+
+
+def test_base_param_hook():
+    torch.manual_seed(0)
+    model = Net(checkpoint=True).cuda()
+    model.train()
+    inputs = net_data()
+
+    def run_model(model, inputs, use_param_hook=False):
+        if use_param_hook:
+            class HooKWrapper:
+                def __init__(self) -> None:
+                    self.hook_triggered_times = 0
+
+                def wrapper_func(self):
+                    def hook(param, grad) -> torch.Tensor or None:
+                        self.hook_triggered_times += 1
+                        return grad
+                    return hook
+
+            hookwrapper = HooKWrapper()
+            param_list = [p for p in model.parameters()]
+            hook_mgr = BaseParamHookMgr(param_list)
+            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
+
+        model.zero_grad(set_to_none=True)
+
+        with torch.cuda.amp.autocast():
+            y = model(*inputs)
+            loss = y.sum()
+        loss.backward()
+
+        if use_param_hook:
+            hook_mgr.remove_hooks()
+            return hookwrapper.hook_triggered_times
+
+    model_copy = copy.deepcopy(model)
+
+    run_model(model, inputs, False)
+    ret2 = run_model(model_copy, inputs, True)
+
+    # Make sure the param hook is fired only once per parameter, even with parameter sharing
+    assert ret2 == len(list(model.parameters()))
+
+    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
+        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
+
+if __name__ == '__main__':
+    test_base_param_hook()
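
For reference, a minimal usage sketch of the new `BaseParamHookMgr` outside the test harness. This is not part of the patch; the toy model and `count_hook` below are illustrative assumptions. It registers a counting hook on each parameter, runs one backward pass, and then removes the hooks.

```python
# Minimal sketch (illustrative, not part of the patch): count hook invocations
# with BaseParamHookMgr on a toy model, then detach the hooks.
import torch
from torch import nn
from colossalai.engine.paramhooks import BaseParamHookMgr

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
fired = []

def count_hook(param, grad):
    # Called as hook(param, grad) once a gradient for `param` is computed.
    fired.append(param.shape)
    return grad  # return the (possibly rewritten) gradient, or None to leave it unchanged

mgr = BaseParamHookMgr(list(model.parameters()))
mgr.register_backward_hooks(count_hook)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
assert len(fired) == len(list(model.parameters()))  # one call per parameter

mgr.remove_hooks()
```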