diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py new file mode 100644 index 000000000..ed9d51d9a --- /dev/null +++ b/tests/test_gemini/test_param_op.py @@ -0,0 +1,95 @@ +from colossalai.gemini.paramhooks import BaseParamHookMgr +from torch import nn +import torch +import torch.nn.functional as F +import copy + + +class SubNet(nn.Module): + + def __init__(self, out_features) -> None: + super().__init__() + self.bias = nn.Parameter(torch.zeros(out_features)) + + def forward(self, x, weight): + return F.linear(x, weight, self.bias) + + +class Net(nn.Module): + + def __init__(self, checkpoint=False) -> None: + super().__init__() + self.fc1 = nn.Linear(5, 5) + self.sub_fc = SubNet(5) + self.fc2 = nn.Linear(5, 1) + + def forward(self, x): + x = self.fc1(x) + x = self.sub_fc(x, self.fc1.weight) + x = self.fc1(x) + x = self.fc2(x) + return x + + +def net_data(): + return (torch.randn(2, 5, dtype=torch.float, device='cuda'),) + + +def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool: + if loose: + return torch.allclose(tensor_a, tensor_b, atol=1e-3, rtol=1e-3) + return torch.allclose(tensor_a, tensor_b) + + +def test_base_param_hook(): + torch.manual_seed(0) + model = Net(checkpoint=True).cuda() + model.train() + inputs = net_data() + + def run_model(model, inputs, use_param_hook=False): + if use_param_hook: + + class HooKWrapper: + + def __init__(self) -> None: + self.hook_triggered_times = 0 + + def wrapper_func(self): + + def hook(param, grad) -> torch.Tensor or None: + self.hook_triggered_times += 1 + return grad + + return hook + + hookwrapper = HooKWrapper() + param_list = [p for p in model.parameters()] + hook_mgr = BaseParamHookMgr(param_list) + hook_mgr.register_backward_hooks(hookwrapper.wrapper_func()) + + model.zero_grad(set_to_none=True) + + with torch.cuda.amp.autocast(): + y = model(*inputs) + loss = y.sum() + loss.backward() + + if use_param_hook: + hook_mgr.remove_hooks() + return hookwrapper.hook_triggered_times + + model_copy = copy.deepcopy(model) + + run_model(model, inputs, False) + ret2 = run_model(model_copy, inputs, True) + + # Make sure param hook has only be fired once in case of parameter sharing + assert ret2 == len(list(model.parameters())) + + for p, p_copy in zip(model.parameters(), model_copy.parameters()): + assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}" + + +if __name__ == '__main__': + test_base_param_hook()