diff --git a/tests/components_to_test/__init__.py b/tests/components_to_test/__init__.py
index f87d35ff9..02f877c6a 100644
--- a/tests/components_to_test/__init__.py
+++ b/tests/components_to_test/__init__.py
@@ -1 +1 @@
-from . import repeated_computed_layer, resnet, nested_model, bert, no_leaf_module, simple_net, gpt
+from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
diff --git a/tests/components_to_test/inline_op_model.py b/tests/components_to_test/inline_op_model.py
new file mode 100644
index 000000000..4fb7e55b2
--- /dev/null
+++ b/tests/components_to_test/inline_op_model.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.nn import CheckpointModule
+
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator
+
+
+class InlineOpModule(CheckpointModule):
+    """
+    a module with inline Ops
+    """
+
+    def __init__(self, checkpoint=False) -> None:
+        super().__init__(checkpoint=checkpoint)
+        self.proj1 = nn.Linear(4, 8)
+        self.weight = nn.Parameter(torch.randn(8, 8))
+        self.proj2 = nn.Linear(8, 4)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        # inline add_
+        x.add_(10)
+        x = F.linear(x, self.weight)
+        # inline relu_
+        x = torch.relu_(x)
+        x = self.proj2(x)
+        return x
+
+
+class DummyDataLoader(DummyDataGenerator):
+
+    def generate(self):
+        data = torch.rand(16, 4)
+        label = torch.randint(low=0, high=2, size=(16,))
+        return data, label
+
+
+@non_distributed_component_funcs.register(name='inline_op_module')
+def get_training_components():
+
+    def model_builder(checkpoint=True):
+        return InlineOpModule(checkpoint)
+
+    trainloader = DummyDataLoader()
+    testloader = DummyDataLoader()
+
+    criterion = torch.nn.CrossEntropyLoss()
+    from colossalai.nn.optimizer import HybridAdam
+    return model_builder, trainloader, testloader, HybridAdam, criterion
diff --git a/tests/test_gemini/test_param_op.py b/tests/test_gemini/test_param_op.py
index ed9d51d9a..f8f7c34d0 100644
--- a/tests/test_gemini/test_param_op.py
+++ b/tests/test_gemini/test_param_op.py
@@ -1,38 +1,9 @@
-from colossalai.gemini.paramhooks import BaseParamHookMgr
-from torch import nn
-import torch
-import torch.nn.functional as F
 import copy
+
+import torch
 
-
-class SubNet(nn.Module):
-
-    def __init__(self, out_features) -> None:
-        super().__init__()
-        self.bias = nn.Parameter(torch.zeros(out_features))
-
-    def forward(self, x, weight):
-        return F.linear(x, weight, self.bias)
-
-
-class Net(nn.Module):
-
-    def __init__(self, checkpoint=False) -> None:
-        super().__init__()
-        self.fc1 = nn.Linear(5, 5)
-        self.sub_fc = SubNet(5)
-        self.fc2 = nn.Linear(5, 1)
-
-    def forward(self, x):
-        x = self.fc1(x)
-        x = self.sub_fc(x, self.fc1.weight)
-        x = self.fc1(x)
-        x = self.fc2(x)
-        return x
-
-
-def net_data():
-    return (torch.randn(2, 5, dtype=torch.float, device='cuda'),)
+from colossalai.gemini.paramhooks import BaseParamHookMgr
+from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> bool:
@@ -41,54 +12,68 @@ def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor, loose=False) -> boo
     return torch.allclose(tensor_a, tensor_b)
 
 
+def run_model(model, inputs, label, criterion, use_param_hook=False):
+    if use_param_hook:
+
+        class HooKWrapper:
+
+            def __init__(self) -> None:
+                self.hook_triggered_times = 0
+
+            def wrapper_func(self):
+
+                def hook(param, grad) -> torch.Tensor or None:
+                    self.hook_triggered_times += 1
+                    return grad
+
+                return hook
+
+        hookwrapper = HooKWrapper()
+        param_list = [p for p in model.parameters()]
+        hook_mgr = BaseParamHookMgr(param_list)
+        hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
+
+    model.zero_grad(set_to_none=True)
+
+    with torch.cuda.amp.autocast():
+        if criterion:
+            y = model(inputs)
+            loss = criterion(y, label)
+        else:
+            loss = model(inputs, label)
+        loss = loss.float()
+    loss.backward()
+
+    if use_param_hook:
+        hook_mgr.remove_hooks()
+        return hookwrapper.hook_triggered_times
+
+
 def test_base_param_hook():
-    torch.manual_seed(0)
-    model = Net(checkpoint=True).cuda()
-    model.train()
-    inputs = net_data()
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module']
+    # test_models = ['bert']
 
-    def run_model(model, inputs, use_param_hook=False):
-        if use_param_hook:
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
 
-            class HooKWrapper:
+        torch.manual_seed(0)
+        model = model_builder(checkpoint=True).cuda()
+        model.train()
 
-                def __init__(self) -> None:
-                    self.hook_triggered_times = 0
+        for i, (inputs, label) in enumerate(train_dataloader):
+            if i > 0:
+                break
+            model_copy = copy.deepcopy(model)
 
-                def wrapper_func(self):
+            run_model(model, inputs.cuda(), label.cuda(), criterion, False)
+            ret2 = run_model(model_copy, inputs.cuda(), label.cuda(), criterion, True)
 
-                    def hook(param, grad) -> torch.Tensor or None:
-                        self.hook_triggered_times += 1
-                        return grad
+            # Make sure the param hook has only been fired once per parameter in case of parameter sharing
+            assert ret2 == len(list(model.parameters()))
 
-                    return hook
-
-            hookwrapper = HooKWrapper()
-            param_list = [p for p in model.parameters()]
-            hook_mgr = BaseParamHookMgr(param_list)
-            hook_mgr.register_backward_hooks(hookwrapper.wrapper_func())
-
-        model.zero_grad(set_to_none=True)
-
-        with torch.cuda.amp.autocast():
-            y = model(*inputs)
-            loss = y.sum()
-        loss.backward()
-
-        if use_param_hook:
-            hook_mgr.remove_hooks()
-            return hookwrapper.hook_triggered_times
-
-    model_copy = copy.deepcopy(model)
-
-    run_model(model, inputs, False)
-    ret2 = run_model(model_copy, inputs, True)
-
-    # Make sure param hook has only be fired once in case of parameter sharing
-    assert ret2 == len(list(model.parameters()))
-
-    for p, p_copy in zip(model.parameters(), model_copy.parameters()):
-        assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
+            for p, p_copy in zip(model.parameters(), model_copy.parameters()):
+                assert allclose(p.grad, p_copy.grad), f"{p.grad} vs {p_copy.grad}"
 
 
 if __name__ == '__main__':
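For reviewers, here is a minimal sketch of the pattern the reworked test exercises: fetch a registered component, attach `BaseParamHookMgr` to its parameters, run one forward/backward pass, and check that the hook fired once per parameter. It assumes a CUDA device and that it is run from the repository root so that `tests.components_to_test` is importable; `counting_hook` and `fired` are illustrative names, the rest follows the patch above.

```python
from colossalai.gemini.paramhooks import BaseParamHookMgr
from tests.components_to_test.registry import non_distributed_component_funcs

# Pull the components this patch registers under 'inline_op_module'.
get_components = non_distributed_component_funcs.get_callable('inline_op_module')
model_builder, train_dataloader, _, _, criterion = get_components()

model = model_builder(checkpoint=True).cuda()
model.train()

# Count backward-hook firings; with no parameter sharing the hook should
# fire exactly once per parameter, which is what test_base_param_hook asserts.
fired = {'count': 0}


def counting_hook(param, grad):
    fired['count'] += 1
    return grad


param_list = list(model.parameters())
hook_mgr = BaseParamHookMgr(param_list)
hook_mgr.register_backward_hooks(counting_hook)

# One forward/backward pass on the dummy data, mirroring the loop in the test.
for data, label in train_dataloader:
    loss = criterion(model(data.cuda()), label.cuda())
    loss.backward()
    break

hook_mgr.remove_hooks()
assert fired['count'] == len(param_list)
```

The `inline_op_module` component is the interesting addition here: its forward pass goes through in-place ops (`add_`, `relu_`), so the check above also covers graphs where intermediate tensors are modified in place.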