From 0bebda6ea501591d100cc151578482bf5d428e1c Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 25 Mar 2022 12:24:18 +0800
Subject: [PATCH] [zero] fix init device bug in zero init context unittest (#516)

---
 .../utils/memory_tracer/async_memtracer.py    | 22 +++----------------
 .../utils/memory_tracer/memstats_collector.py |  4 ++--
 .../memory_tracer/model_data_memtracer.py     |  3 +++
 .../utils/memory_utils/memory_monitor.py      | 22 +++++++++++++++++++
 colossalai/utils/memory_utils/utils.py        |  2 +-
 colossalai/zero/init_ctx/init_context.py      | 15 +++++++------
 .../zero/shard_utils/tensor_shard_strategy.py |  3 +++
 .../test_init_context.py                      | 21 +++++++++++-------
 8 files changed, 55 insertions(+), 37 deletions(-)

diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
index 4091f94aa..74dd278c8 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -2,26 +2,10 @@
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
 import pickle
 
-from colossalai.utils import get_current_device
 import torch
-
-def get_cuda_memory_used(device: torch.device) -> int:
-    """
-    Get the free memory info of device.
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-
-    assert device.type == 'cuda'
-
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
+from colossalai.utils import get_current_device
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 
 
 class AsyncMemoryMonitor:
@@ -97,7 +81,7 @@ class AsyncMemoryMonitor:
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
-                get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+                colo_cuda_memory_used(),
             )
             sleep(self.interval)
         return max_usage
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index 9231cd5a0..b50888a19 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from .async_memtracer import get_cuda_memory_used
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.utils import get_current_device
 
 import torch
@@ -55,7 +55,7 @@ class MemStatsCollector:
         sampling_cnt = self._sampling_cnter.sampling_cnt
         assert sampling_cnt == len(self._overall_cuda)
         self._model_data_cuda.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
-        self._overall_cuda.append(get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
+        self._overall_cuda.append(colo_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')))
         self._sampling_cnter.advance()
 
     def fetch_memstats(self) -> (int, int):
diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index e5a742f64..e8cb9f7c6 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -44,6 +44,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
 
+    def clear(self) -> None:
+        self._cuda_usage = 0
+
     @property
     def cpu_usage(self):
         return self._cpu_usage
diff --git a/colossalai/utils/memory_utils/memory_monitor.py b/colossalai/utils/memory_utils/memory_monitor.py
index eccfbc690..873e51ecd 100644
--- a/colossalai/utils/memory_utils/memory_monitor.py
+++ b/colossalai/utils/memory_utils/memory_monitor.py
@@ -9,6 +9,28 @@ import torch
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
+from colossalai.utils.cuda import get_current_device
+from typing import Optional
+
+
+def colo_cuda_memory_used(device: Optional[torch.device] = None) -> int:
+    """
+    Get the free memory info of device.
+    :param device: a torch device instance or None
+    :type device: Optional[torch.device]
+    :return: current memory usage, sized by Byte
+    :rtype: int
+    """
+    if device:
+        assert device.type == 'cuda'
+    else:
+        device = torch.device(f'cuda:{get_current_device()}')
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
 
 
 def bytes_to_GB(val, decimal=2):
diff --git a/colossalai/utils/memory_utils/utils.py b/colossalai/utils/memory_utils/utils.py
index b1c24994c..d391c91f7 100644
--- a/colossalai/utils/memory_utils/utils.py
+++ b/colossalai/utils/memory_utils/utils.py
@@ -3,7 +3,7 @@ from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
-from typing import Union
+from typing import Union, Optional
 
 _GLOBAL_CUDA_MEM_FRACTION = 1.0
 
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index be73da796..9482bfe24 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -6,16 +6,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 
-# Inserts _post_init_method at the end of init method
-
-# for all sub classes of torch.nn.Module
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self):
@@ -144,8 +142,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
 
         GLOBAL_MODEL_DATA_TRACER.close()
-        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
+        model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
+        self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])
 
     def _post_init_method(self, module: torch.nn.Module):
         """
@@ -178,8 +179,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                if param.col_attr.sharded_data_tensor.device.type == 'cuda':
-                    GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 9383889e9..7f2d2684e 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -23,6 +23,9 @@ class TensorShardStrategy(BaseShardStrategy):
     def _shard_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
         if t.is_sharded:
             return
+        if t.payload.device.type == 'cuda':
+            assert t.payload.device.index == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
+                f" but current cuda device is {get_current_device()}"
         sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
         t.reset_payload(sharded_payload)
         t.is_sharded = True
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index f1b41ee09..84a8b63ff 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -19,17 +19,24 @@ from tests.components_to_test.registry import non_distributed_component_funcs
 from common import CONFIG
 
 
-@parameterize("init_device", [torch.device('cpu'), torch.device(f'cuda:{get_current_device()}')])
+@parameterize("init_device_type", ['cpu', 'cuda'])
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
-def run_model_test(init_device, shard_strategy_class):
+def run_model_test(init_device_type, shard_strategy_class):
     for get_components_func in non_distributed_component_funcs:
         model_builder, _, _, _, _ = get_components_func()
         model_numel_tensor = torch.zeros(1, dtype=torch.int)
+        if init_device_type == 'cuda':
+            init_device = torch.device(f"cuda:{get_current_device()}")
+        elif init_device_type == 'cpu':
+            init_device = torch.device("cpu")
+        else:
+            continue
         with ZeroInitContext(convert_fp16=True,
                              target_device=init_device,
                              shard_strategy=shard_strategy_class(),
                              shard_param=True,
-                             model_numel_tensor=model_numel_tensor):
+                             model_numel_tensor=model_numel_tensor,
+                             rm_torch_payload_on_the_fly=False):
             model = model_builder(checkpoint=True)
 
         for param in model.parameters():
@@ -38,11 +45,9 @@
             assert param.col_attr.sharded_data_tensor.is_sharded
             assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
                 f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
-
-        print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
-        print(f'numel {model_numel_tensor}')
-    if init_device.type == 'cuda':
-        assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        if init_device.type == 'cuda':
+            assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage > 0)
+        GLOBAL_MODEL_DATA_TRACER.clear()
 
 
 def run_dist(rank, world_size, port):
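
Note: the API change this patch makes is that callers no longer build a torch.device themselves; colo_cuda_memory_used() defaults to the current CUDA device and reports bytes allocated by PyTorch, resetting the peak counter after each call. The following standalone sketch mirrors that behavior for readers who do not have Colossal-AI installed. The name cuda_memory_used_bytes and the demo under __main__ are illustrative assumptions, not part of the patch; the canonical implementation is the one added to colossalai/utils/memory_utils/memory_monitor.py above.

    from typing import Optional

    import torch


    def cuda_memory_used_bytes(device: Optional[torch.device] = None) -> int:
        """Bytes currently allocated by PyTorch on a CUDA device.

        Mirrors colo_cuda_memory_used() from this patch: falls back to the
        current CUDA device when no device is given. Illustrative sketch only.
        """
        if device is not None:
            assert device.type == 'cuda'
        else:
            device = torch.device(f'cuda:{torch.cuda.current_device()}')

        used = torch.cuda.memory_allocated(device)
        # Reset the peak-memory counter so later peak queries reflect the next
        # measurement interval, as the patched helper does.
        if hasattr(torch.cuda, "reset_peak_memory_stats"):
            torch.cuda.reset_peak_memory_stats(device)
        return used


    if __name__ == "__main__":
        if torch.cuda.is_available():
            buf = torch.empty(1024, 1024, device='cuda')    # roughly 4 MB of float32
            print(f"allocated: {cuda_memory_used_bytes() / 1e6:.1f} MB")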