From 920c5889a732bfd3752b243eb26111a913f747df Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 25 Mar 2022 14:02:55 +0800
Subject: [PATCH] [zero] add colo move inline (#521)

---
 colossalai/engine/ophooks/zero_hook.py   |  9 +++-----
 colossalai/utils/memory_utils/utils.py   | 28 ++++++++++++++++++++++++
 colossalai/zero/init_ctx/init_context.py |  2 +-
 tests/test_utils/test_commons.py         |  8 ++++++-
 4 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 913f82ed7..56c32d453 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -10,6 +10,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
 
 
 @OPHOOKS.register_module
@@ -37,9 +38,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:
@@ -61,9 +60,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py
index d391c91f7..52bb487d0 100644
--- a/colossalai/utils/memory_utils/utils.py
+++ b/colossalai/utils/memory_utils/utils.py
@@ -65,6 +65,34 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
 
 
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+    """
+    Move a tensor to the target_device in place.
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): the tensor to be moved
+    """
+
+    if isinstance(t, ShardedTensor):
+        t_payload = t.payload
+    elif isinstance(t, torch.Tensor):
+        t_payload = t
+    else:
+        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
+
+    assert isinstance(target_device, torch.device)
+
+    # deal with torch.device('cpu') and torch.device('cpu:0')
+    if t_payload.device.type == target_device.type:
+        return
+
+    if target_device.type == 'cuda':
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+    elif target_device.type == 'cpu':
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
+
+    t_payload.data = t_payload.data.to(target_device)
+
+
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     """colo_model_data_move_to_cpu
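Note on the helper above: the following is a minimal standalone sketch of the
semantics it implements, using a hypothetical _ToyTracer in place of the real
GLOBAL_MODEL_DATA_TRACER (all names below are illustrative, not ColossalAI API).
It shows the two key behaviors: comparing device *types* so that
torch.device('cpu') and torch.device('cpu:0') count as the same place, and
updating the tracer only when the tensor actually crosses the CPU/CUDA boundary.

import torch

class _ToyTracer:
    """Hypothetical stand-in for GLOBAL_MODEL_DATA_TRACER (illustration only)."""

    def __init__(self):
        self.cuda_usage = 0

    def add_tensor(self, t: torch.Tensor):
        self.cuda_usage += t.numel() * t.element_size()

    def delete_tensor(self, t: torch.Tensor):
        self.cuda_usage -= t.numel() * t.element_size()

_TRACER = _ToyTracer()

def tensor_move_inline_sketch(t: torch.Tensor, target_device: torch.device) -> None:
    # Compare device *types*, not full devices, so 'cpu' and 'cpu:0'
    # are treated as the same place and trigger no move.
    if t.device.type == target_device.type:
        return
    # Bookkeeping happens only when the tensor crosses the CPU/CUDA boundary.
    if target_device.type == 'cuda':
        _TRACER.add_tensor(t)
    elif target_device.type == 'cpu':
        _TRACER.delete_tensor(t)
    # Assigning to .data moves the storage while keeping the same Python
    # object, so aliases such as param.data remain valid after the move.
    t.data = t.data.to(target_device)

# A same-device "move" is a no-op and never touches the tracer:
x = torch.ones(4)
tensor_move_inline_sketch(x, torch.device('cpu:0'))
assert _TRACER.cuda_usage == 0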
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 9482bfe24..bd765f1a6 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -143,7 +143,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
         model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
         sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
         self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
         self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py
index 9b42b10f9..dad0cacfc 100644
--- a/tests/test_utils/test_commons.py
+++ b/tests/test_utils/test_commons.py
@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
 from colossalai.zero.sharded_param import ShardedTensor
 
@@ -40,6 +40,12 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+
+    assert (tgt_t.device.type == 'cuda')
+    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
+    assert (tgt_t.device.type == 'cpu')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
+
     GLOBAL_MODEL_DATA_TRACER.close()
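For reference, the new helper also accepts a plain torch.Tensor, as the type
check in its body shows. A hedged usage sketch mirroring the test path above
(assumes colossalai at this revision is installed, CUDA is available, and the
global tracer has been started as in the test fixture; not a verified recipe):

import torch
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

t = torch.ones(2, 3, device='cuda')             # model data living on GPU
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
assert t.device.type == 'cpu'                   # moved in place, same Python object
# Moving to CPU again is a no-op thanks to the device-type short-circuit:
colo_model_data_tensor_move_inline(t, torch.device('cpu'))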