From dfc88b85ea3cecfe7e12e64c091fc94fe39e81a8 Mon Sep 17 00:00:00 2001
From: Ziyue Jiang
Date: Wed, 11 May 2022 10:54:19 +0800
Subject: [PATCH] [Tensor] simplify named param (#928)

* simplify ColoModulize

* simplify ColoModulize

* polish

* polish
---
 colossalai/utils/model/colo_init_context.py | 37 +++------------------
 tests/test_tensor/test_model.py             | 12 +++++--
 2 files changed, 14 insertions(+), 35 deletions(-)

diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 877c8428c..7aed1d471 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -90,56 +90,28 @@ def ColoModulize(module):
     Replacing the parameters() and named_parameters() with our customized ones
     """
 
-    def named_params_with_colotensor(
-        module: nn.Module,
-        prefix: str = '',
-        recurse: bool = True,
-    ) -> Iterator[Tuple[str, Union[nn.Parameter, ColoTensor]]]:
-        modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)]
-
-        memo = set()
-        for mod_prefix, mod in modules:
-            # find all colotensors tensor params
-            for name, val in vars(mod).items():
-                if isinstance(val, ColoTensor) and val not in memo:
-                    memo.add(val)
-                    name = mod_prefix + ('.' if mod_prefix else '') + name
-                    yield name, val
-
-        # find all nn.Parameters
-        for name, val in module.old_named_parameters(recurse=recurse):
-            yield name, val
-
     def fake_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for p in module.old_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield p
 
     def fake_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
+        for name, p in module.old_named_parameters(*args, **kargs):
             if isinstance(p, ColoTensor):
                 yield name, p.torch_tensor()
             elif isinstance(p, torch.Tensor):
                 yield name, p
 
-    def colo_parameters(self, *args, **kargs):
-        for _, p in named_params_with_colotensor(self, *args, **kargs):
-            yield p
-
-    def colo_named_parameters(self, *args, **kargs):
-        for name, p in named_params_with_colotensor(self, *args, **kargs):
-            yield name, p
-
     module.old_named_parameters = module.named_parameters
     module.old_parameters = module.parameters
 
     funcType = types.MethodType
     module.parameters = funcType(fake_parameters, module)
    module.named_parameters = funcType(fake_named_parameters, module)
-    module.colo_parameters = funcType(colo_parameters, module)
-    module.colo_named_parameters = funcType(colo_named_parameters, module)
+    module.colo_parameters = module.old_parameters
+    module.colo_named_parameters = module.old_named_parameters
     module._colo_visited = True
 
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
@@ -154,7 +126,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         self._lazy_memory_allocate = lazy_memory_allocate
         self._device = device
 
-        # TODO(jzy) replace it with old __setattr__ in the exit() of context?
         torch.nn.Module.__setattr__ = _setattr_with_colotensor
         torch.nn.Module.register_parameter = _register_parameter_with_colotensor
 
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 8f8fb4597..7fe850af1 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -1,3 +1,4 @@
+from colossalai.tensor.colo_parameter import ColoParameter
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 import colossalai
@@ -371,7 +372,13 @@ def _run_pretrain_load():
     dict_col = {}
     for name, param in model_pretrained.named_parameters():
         dict_pretrained[name] = param
-    for name, param in model.named_parameters():
+    c1 = 0
+    c2 = 0
+    for name, param in model.colo_named_parameters():
+        if isinstance(param, ColoParameter):
+            c1 = c1 + 1
+        else:
+            c2 = c2 + 1
         dict_col[name] = param
 
     for name, param in dict_pretrained.items():
@@ -416,4 +423,5 @@ if __name__ == '__main__':
     # test_model_parameters()
     # test_colo_optimizer()
     # test_model()
-    _test_pretrain_load(4)
+    # _test_pretrain_load(4)
+    _run_pretrain_load()
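
Note (not part of the patch): below is a minimal, self-contained sketch of the
monkey-patching pattern this diff leaves in place, for readers without the
surrounding file. `DummyColoTensor` and `modulize` are illustrative stand-ins,
not colossalai APIs; the real code keys on `ColoTensor` and its
`torch_tensor()` method exactly as shown in the hunks above.

import types

import torch
import torch.nn as nn


class DummyColoTensor(torch.Tensor):
    # Stand-in for ColoTensor: torch_tensor() exposes the plain tensor view.
    def torch_tensor(self) -> torch.Tensor:
        return self.as_subclass(torch.Tensor)


def modulize(module: nn.Module) -> None:
    # Mirror of the patched ColoModulize: save the original iterators,
    # rebind the public ones to unwrapping variants, and alias the colo_*
    # accessors straight to the saved originals (the simplification).
    def fake_parameters(self, *args, **kwargs):
        for p in module.old_parameters(*args, **kwargs):
            if isinstance(p, DummyColoTensor):
                yield p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield p

    def fake_named_parameters(self, *args, **kwargs):
        for name, p in module.old_named_parameters(*args, **kwargs):
            if isinstance(p, DummyColoTensor):
                yield name, p.torch_tensor()
            elif isinstance(p, torch.Tensor):
                yield name, p

    module.old_named_parameters = module.named_parameters
    module.old_parameters = module.parameters
    module.parameters = types.MethodType(fake_parameters, module)
    module.named_parameters = types.MethodType(fake_named_parameters, module)
    module.colo_parameters = module.old_parameters
    module.colo_named_parameters = module.old_named_parameters
    module._colo_visited = True


if __name__ == '__main__':
    m = nn.Linear(2, 2)
    modulize(m)
    # parameters()/named_parameters() now unwrap wrapped tensors, while
    # colo_named_parameters() yields whatever the module originally held
    # (here, plain nn.Parameter objects).
    assert all(isinstance(p, torch.Tensor) for p in m.parameters())
    assert dict(m.colo_named_parameters()).keys() == {'weight', 'bias'}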