diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index eb6b8b128..20ceef71b 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -43,9 +43,10 @@ class ColoTensor(object):
                             torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t
 
-    def del_torch_tensor(self) -> None:
-        self._size = (0,)
-        self._torch_tensor = torch.empty(self._size)
+    def del_torch_tensor(self, save_shape=False) -> None:
+        if not save_shape:
+            self._size = (0,)
+        self._torch_tensor = torch.empty((0,))
 
     def torch_tensor(self) -> torch.Tensor:
         if self._torch_tensor.numel() == 0:
diff --git a/colossalai/utils/__init__.py b/colossalai/utils/__init__.py
index a35d3dcc3..15ef96de3 100644
--- a/colossalai/utils/__init__.py
+++ b/colossalai/utils/__init__.py
@@ -11,16 +11,47 @@ from .memory import (report_memory_usage, colo_device_memory_used, colo_set_proc
                      colo_device_memory_capacity, colo_set_cpu_memory_capacity, colo_get_cpu_memory_capacity)
 from .timer import MultiTimer, Timer
 from .tensor_detector import TensorDetector
-from .model.init_context import InsertPostInitMethodToModuleSubClasses
+from .model.utils import InsertPostInitMethodToModuleSubClasses
+from .model.colo_init_context import ColoInitContext
 
 __all__ = [
-    'checkpoint', 'free_port', 'print_rank_0', 'sync_model_param', 'is_dp_rank_0', 'is_tp_rank_0',
-    'is_no_pp_or_last_stage', 'is_using_ddp', 'is_using_pp', 'is_using_sequence', 'conditional_context',
-    'is_model_parallel_parameter', 'clip_grad_norm_fp32', 'count_zeros_fp32', 'copy_tensor_parallel_attributes',
-    'param_is_not_tensor_parallel_duplicate', 'get_current_device', 'synchronize', 'empty_cache', 'set_to_cuda',
-    'report_memory_usage', 'colo_device_memory_capacity', 'colo_device_memory_used', 'colo_set_process_memory_fraction',
-    'Timer', 'MultiTimer', 'multi_tensor_applier', 'DataParallelSampler', 'get_dataloader',
-    'switch_virtual_pipeline_parallel_rank', 'TensorDetector', 'load_checkpoint', 'save_checkpoint',
-    'ensure_path_exists', 'disposable', 'colo_set_cpu_memory_capacity', 'colo_get_cpu_memory_capacity',
-    'InsertPostInitMethodToModuleSubClasses'
+    'checkpoint',
+    'free_port',
+    'print_rank_0',
+    'sync_model_param',
+    'is_dp_rank_0',
+    'is_tp_rank_0',
+    'is_no_pp_or_last_stage',
+    'is_using_ddp',
+    'is_using_pp',
+    'is_using_sequence',
+    'conditional_context',
+    'is_model_parallel_parameter',
+    'clip_grad_norm_fp32',
+    'count_zeros_fp32',
+    'copy_tensor_parallel_attributes',
+    'param_is_not_tensor_parallel_duplicate',
+    'get_current_device',
+    'synchronize',
+    'empty_cache',
+    'set_to_cuda',
+    'report_memory_usage',
+    'colo_device_memory_capacity',
+    'colo_device_memory_used',
+    'colo_set_process_memory_fraction',
+    'Timer',
+    'MultiTimer',
+    'multi_tensor_applier',
+    'DataParallelSampler',
+    'get_dataloader',
+    'switch_virtual_pipeline_parallel_rank',
+    'TensorDetector',
+    'load_checkpoint',
+    'save_checkpoint',
+    'ensure_path_exists',
+    'disposable',
+    'colo_set_cpu_memory_capacity',
+    'colo_get_cpu_memory_capacity',
+    'InsertPostInitMethodToModuleSubClasses',
+    'ColoInitContext',
 ]
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
new file mode 100644
index 000000000..d32cc58d4
--- /dev/null
+++ b/colossalai/utils/model/colo_init_context.py
@@ -0,0 +1,40 @@
+from .utils import InsertPostInitMethodToModuleSubClasses
+import torch
+# from colossalai.logging import get_dist_logger
+from colossalai.tensor import ColoTensor
+
+# _orig_torch_empty = torch.empty
+
+
+class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
+
+    def __init__(self, lazy_memory_allocate=False):
+        super().__init__()
+        self._lazy_memory_allocate = lazy_memory_allocate
+
+    def _pre_context_exec(self):
+        """
+        The callback function when entering the context.
+        """
+        pass
+
+    def _post_context_exec(self):
+        """The callback function when exiting the context.
+        """
+        pass
+
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        FIXME(fjr) The module may be passed to this function multiple times?
+        """
+        name_list = []
+        for name, param in module.named_parameters():
+            if isinstance(param, ColoTensor):
+                continue
+            name_list.append((name, param))
+
+        save_torch_payload = True if not self._lazy_memory_allocate else False
+        for name, param in name_list:
+            delattr(module, name)
+            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param.data, save_payload=save_torch_payload))
diff --git a/colossalai/utils/model/init_context.py b/colossalai/utils/model/utils.py
similarity index 100%
rename from colossalai/utils/model/init_context.py
rename to colossalai/utils/model/utils.py
diff --git a/tests/test_tensor/test_context.py b/tests/test_tensor/test_context.py
new file mode 100644
index 000000000..3c998fa66
--- /dev/null
+++ b/tests/test_tensor/test_context.py
@@ -0,0 +1,27 @@
+from colossalai.utils import ColoInitContext
+
+from numpy import allclose, require
+import torch
+from colossalai.tensor import ColoTensor
+from copy import deepcopy
+
+
+def test_linear():
+    in_dim = 4
+    out_dim = 5
+
+    with ColoInitContext(lazy_memory_allocate=True) as ctx:
+        fc = torch.nn.Linear(in_dim, out_dim, bias=True)
+
+    print(fc.weight.numel())
+    print(fc.bias.numel())
+
+    # lazy_memory_allocate=True, no payload is maintained
+    assert fc.weight._torch_tensor.numel() == 0
+
+    fc.weight.torch_tensor()
+    assert fc.weight._torch_tensor.numel() == in_dim * out_dim
+
+
+if __name__ == '__main__':
+    test_linear()
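
For quick reference, a minimal usage sketch of the new context manager, mirroring the test added above (not part of the diff itself): with lazy_memory_allocate=True, the parameters of any module built inside the context are replaced by ColoTensors whose payload stays empty until torch_tensor() is first called.

import torch

from colossalai.tensor import ColoTensor
from colossalai.utils import ColoInitContext

# Any module constructed inside the context gets its parameters converted to ColoTensors.
with ColoInitContext(lazy_memory_allocate=True):
    fc = torch.nn.Linear(4, 5, bias=True)

assert isinstance(fc.weight, ColoTensor)
# Lazy mode: no payload is kept at construction time.
assert fc.weight._torch_tensor.numel() == 0
# Requesting the payload materializes the underlying torch.Tensor.
fc.weight.torch_tensor()
assert fc.weight._torch_tensor.numel() == 4 * 5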