diff --git a/colossalai/engine/ophooks/__init__.py b/colossalai/engine/ophooks/__init__.py
index c2664b2af..9e81ba56d 100644
--- a/colossalai/engine/ophooks/__init__.py
+++ b/colossalai/engine/ophooks/__init__.py
@@ -1,3 +1,4 @@
 from .utils import register_ophooks_recursively, BaseOpHook
+from ._memtracer_ophook import MemTracerOpHook
 
 __all__ = ["BaseOpHook", "MemTracerOpHook", "register_ophooks_recursively"]
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index 2aa32b829..de2b0435c 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -1,6 +1,6 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor
 import torch
 import time
 from typing import List
@@ -19,7 +19,7 @@ class MemStatsCollector:
     """
 
     def __init__(self) -> None:
-        self._mem_monitor = AsyncMemoryMonitor()
+        self._mem_monitor = SyncCudaMemoryMonitor()
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
 
@@ -31,9 +31,10 @@ class MemStatsCollector:
         self._sampling_time = []
 
         self._start_flag = False
-        self._period_idx = 0
+        self._step_idx = 0
+        self._step_total = 0
 
-    def overall_mem_stats(self, device_type: str):
+    def overall_mem_stats(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
             return self._overall_cuda_list
         elif device_type == 'cpu':
@@ -41,47 +42,23 @@ class MemStatsCollector:
         else:
             raise TypeError
 
-    def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
-        if unit == 'GB':
-            scale = 1e9
-        elif unit == 'MB':
-            scale = 1e6
-        elif unit == 'KB':
-            scale = 1e3
-        elif unit == 'B':
-            scale = 1
-        else:
-            raise TypeError
-
+    def model_data_list(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
-            return [elem / scale for elem in self._model_data_cuda_list]
+            return self._model_data_cuda_list
         elif device_type == 'cpu':
-            return [elem / scale for elem in self._model_data_cpu_list]
-        else:
-            raise TypeError
-
-    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
-        """Non model data stats
-        """
-        if unit == 'GB':
-            scale = 1e9
-        elif unit == 'MB':
-            scale = 1e6
-        elif unit == 'KB':
-            scale = 1e3
-        elif unit == 'B':
-            scale = 1
+            return self._model_data_cpu_list
         else:
             raise TypeError
 
+    def non_model_data_list(self, device_type: str) -> List[int]:
         if device_type == 'cuda':
-            return [elem / scale for elem in self._non_model_data_cuda_list]
+            return self._non_model_data_cuda_list
         elif device_type == 'cpu':
-            return [elem / scale for elem in self._non_model_data_cpu_list]
+            return self._non_model_data_cpu_list
         else:
             raise TypeError
 
-    def max_non_model_data(self, device_type: str) -> int:
+    def next_period_non_model_data_usage(self, device_type: str) -> int:
         """Get max non model data memory usage of current sampling period
 
         Args:
@@ -91,12 +68,10 @@ class MemStatsCollector:
             int: max non model data memory usage of current sampling period
         """
         assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
-        assert len(self._sampling_time) > 0, 'Cannot get mem stats info before collection phase.'
-        next_period_idx = (self._period_idx + 1) % len(self._sampling_time)
-        current_non_model_data = self.non_model_data_list(device_type)[self._period_idx]
-        next_non_model_data = self.non_model_data_list(device_type)[next_period_idx]
-        self._period_idx = next_period_idx
-        return max(current_non_model_data, next_non_model_data)
+        assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
+        next_non_model_data = self.non_model_data_list(device_type)[self._step_idx]
+        self._step_idx = (self._step_idx + 1) % self._step_total
+        return next_non_model_data
 
     @property
     def sampling_time(self):
@@ -107,9 +82,37 @@ class MemStatsCollector:
         self._mem_monitor.start()
 
     def finish_collection(self):
+        self.sample_overall_data()
+        self._step_total = len(self._sampling_time)
         self._start_flag = False
         self._mem_monitor.finish()
 
+    def sample_model_data(self) -> None:
+        """Sampling model data statistics.
+        """
+        if self._start_flag:
+            cuda_mem, cpu_mem = GLOBAL_MODEL_DATA_TRACER.both_mem_usage
+            self._model_data_cuda_list.append(cuda_mem)
+            self._model_data_cpu_list.append(cpu_mem)
+
+    def sample_overall_data(self) -> None:
+        """Sampling non model data statistics.
+        """
+        if self._start_flag:
+            # overall data recording is after model data recording
+            if len(self._model_data_cuda_list) == 0:
+                return
+
+            self._overall_cuda_list.append(self._mem_monitor.finish())
+            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
+
+            assert len(self._model_data_cuda_list) == len(self._overall_cuda_list)
+
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
+            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
+            self._sampling_time.append(time.time())
+            self._mem_monitor.start()
+
     def sample_memstats(self) -> None:
         """
         Sampling memory statistics.
@@ -119,7 +122,7 @@ class MemStatsCollector:
         if self._start_flag:
             self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
             self._overall_cuda_list.append(self._mem_monitor.finish())
-            self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
+            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
 
             self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
             # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
@@ -136,4 +139,5 @@ class MemStatsCollector:
         self._overall_cpu_list = []
 
         self._start_flag = False
-        self._period_idx = 0
+        self._step_idx = 0
+        self._step_total = 0
diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index 4c9e4f804..98228892d 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -101,5 +101,9 @@ class ModelDataTracer(metaclass=SingletonMeta):
         cuda_usage, _ = self._get_mem_usage()
         return cuda_usage
 
+    @property
+    def both_mem_usage(self):
+        return self._get_mem_usage()
+
 
 GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 87a09df44..d7fa64476 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -109,6 +109,5 @@ class ShardedParamV2(object):
 
         if self.param.grad is not None and self.param.grad.data_ptr() not in address_set:
             _update_mem_use(self.param.grad)
-            address_set.add(self.param.grad.data_ptr())
 
         return cuda_mem_use, cpu_mem_use
diff --git a/colossalai/zero/sharded_param/tensor_utils.py b/colossalai/zero/sharded_param/tensor_utils.py
index 8282a4f86..4895becaf 100644
--- a/colossalai/zero/sharded_param/tensor_utils.py
+++ b/colossalai/zero/sharded_param/tensor_utils.py
@@ -13,7 +13,7 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
 
     cuda_use, cpu_use = 0, 0
 
-    mem_use = t.numel() * t.element_size()
+    mem_use = t.storage().size() * t.element_size()
     if t.device.type == 'cuda':
         cuda_use += mem_use
     elif t.device.type == 'cpu':
diff --git a/colossalai/zero/utils/stateful_tensor_mgr.py b/colossalai/zero/utils/stateful_tensor_mgr.py
index c06dcc4a3..107d14f5b 100644
--- a/colossalai/zero/utils/stateful_tensor_mgr.py
+++ b/colossalai/zero/utils/stateful_tensor_mgr.py
@@ -38,10 +38,6 @@ class StatefulTensorMgr(object):
     def adjust_layout(self) -> None:
         """ Adjust the layout of statefuil tensor according to the information provided
        by mem_stats_collector, which should belongs to a Sharded Model.
-
-        Args:
-            mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
-            It contains non-model footprint of a DNN model.
         """
         # find stateful tensor in state COMPUTE
         cuda_demand = 0
diff --git a/colossalai/zero/utils/tensor_placement_policy.py b/colossalai/zero/utils/tensor_placement_policy.py
index d7a977188..d74da56c0 100644
--- a/colossalai/zero/utils/tensor_placement_policy.py
+++ b/colossalai/zero/utils/tensor_placement_policy.py
@@ -61,7 +61,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
         else:
             # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = self.mem_stats_collector.max_non_model_data('cuda')
+            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
         total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
         avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
         if avail_cuda_model_data < cuda_demand:
@@ -71,7 +71,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             freed_cuda_model_data = 0
             to_free_tensor_list = hold_cuda_tensor_list
             if not warmup:
-                next_compute_idx: Dict[StatefulTensor, int] = {t: len(compute_list) for t in hold_cuda_tensor_list}
+                next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensor_list}
                 for i in range(len(compute_list) - 1, compute_idx, -1):
                     if compute_list[i] in next_compute_idx:
                         next_compute_idx[compute_list[i]] = i
diff --git a/colossalai/zero/utils/zero_hook.py b/colossalai/zero/utils/zero_hook.py
index 34d7e0d5a..40b44fc12 100644
--- a/colossalai/zero/utils/zero_hook.py
+++ b/colossalai/zero/utils/zero_hook.py
@@ -36,17 +36,7 @@ class ZeroHook(BaseOpHook):
         self._memstarts_collector = memstarts_collector
         self._stateful_tensor_mgr = stateful_tensor_mgr
 
-    def pre_fwd_exec(self, module: torch.nn.Module, *args):
-
-        for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
-
-        if self._stateful_tensor_mgr:
-            self._stateful_tensor_mgr.adjust_layout()
-        else:
-            for param in module.parameters(recurse=False):
-                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-
+    def gather_parameters(self, module: torch.nn.Module):
         # gather sharded parameters
         if module.param_is_sharded:
             tensor_list = []
@@ -55,10 +45,33 @@ class ZeroHook(BaseOpHook):
                 tensor_list.append(param.colo_attr.sharded_data_tensor)
             self.shard_strategy.gather(tensor_list, self.process_group)
 
-        # record memory statistics
-        if self._memstarts_collector:
-            self._memstarts_collector.sample_memstats()
+    def shard_parameters(self, module: torch.nn.Module):
+        # shard gathered parameters
+        if module.param_is_sharded:
+            tensor_list = []
+            for param in module.parameters(recurse=False):
+                assert hasattr(param, 'colo_attr')
+                tensor_list.append(param.colo_attr.sharded_data_tensor)
+            self.shard_strategy.shard(tensor_list, self.process_group)
 
+    def adjust_module_data(self, module: torch.nn.Module):
+        # record overall data statistics
+        if self._memstarts_collector:
+            self._memstarts_collector.sample_overall_data()
+
+        for param in module.parameters(recurse=False):
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
+        # adjust stateful tensor to get enough CUDA memory
+        self._stateful_tensor_mgr.adjust_layout()
+
+        # record model data statistics
+        if self._memstarts_collector:
+            self._memstarts_collector.sample_model_data()
+
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        self.adjust_module_data(module)
+        self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
@@ -69,41 +82,15 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
 
-        # shard gathered parameters
-        if module.param_is_sharded:
-            tensor_list = []
-            for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
-                tensor_list.append(param.colo_attr.sharded_data_tensor)
-            self.shard_strategy.shard(tensor_list, self.process_group)
+        self.shard_parameters(module)
 
         # remove torch payload
         for param in module.parameters(recurse=False):
             param.colo_attr.set_data_none()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
-
-        for param in module.parameters(recurse=False):
-            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
-
-        if self._stateful_tensor_mgr:
-            self._stateful_tensor_mgr.adjust_layout()
-        else:
-            for param in module.parameters(recurse=False):
-                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
-
-        # gather sharded parameters
-        if module.param_is_sharded:
-            tensor_list = []
-            for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
-                tensor_list.append(param.colo_attr.sharded_data_tensor)
-            self.shard_strategy.gather(tensor_list, self.process_group)
-
-        # record memory statistics
-        if self._memstarts_collector:
-            self._memstarts_collector.sample_memstats()
-
+        self.adjust_module_data(module)
+        self.gather_parameters(module)
         for param in module.parameters(recurse=False):
             param.data = param.colo_attr.data_payload
             assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
@@ -114,13 +101,7 @@ class ZeroHook(BaseOpHook):
         for param in module.parameters(recurse=False):
             param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
 
-        # shard gathered parameters
-        if module.param_is_sharded:
-            tensor_list = []
-            for param in module.parameters(recurse=False):
-                assert hasattr(param, 'colo_attr')
-                tensor_list.append(param.colo_attr.sharded_data_tensor)
-            self.shard_strategy.shard(tensor_list, self.process_group)
+        self.shard_parameters(module)
 
         # remove torch payload
         for param in module.parameters(recurse=False):
diff --git a/tests/test_zero/test_mem_collector.py b/tests/test_zero/test_mem_collector.py
new file mode 100644
index 000000000..62c367701
--- /dev/null
+++ b/tests/test_zero/test_mem_collector.py
@@ -0,0 +1,74 @@
+import torch
+import colossalai
+import pytest
+import torch.multiprocessing as mp
+import torch.nn as nn
+import torch.nn.functional as F
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
+from colossalai.zero.init_ctx import ZeroInitContext
+from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.shard_utils import BucketTensorShardStrategy
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
+from functools import partial
+
+
+class TestModel(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.proj1 = nn.Linear(512, 512)
+        self.weight = nn.Parameter(torch.randn(1024, 512))
+        self.proj2 = nn.Linear(1024, 512)
+
+    def forward(self, x):
+        x = self.proj1(x)
+        x = F.linear(x, self.weight)
+        x = self.proj2(x)
+
+        return x
+
+
+def run_mem_collector_testing():
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    fraction = (50 * 1024**2) / cuda_capacity
+    # limit max memory to 50MB
+    colo_set_process_memory_fraction(fraction)
+    shard_strategy = BucketTensorShardStrategy()
+    with ZeroInitContext(target_device=get_current_device(), shard_strategy=shard_strategy, shard_param=True):
+        model = TestModel()
+
+    model = ShardedModelV2(module=model,
+                           shard_strategy=shard_strategy,
+                           reduce_scatter_bucket_size_mb=1,
+                           tensor_placement_policy='auto')
+
+    data = torch.randn(2, 512, device=get_current_device())
+
+    output = model(data)
+    loss = torch.mean(output)
+    model.backward(loss)
+
+    cuda_model_data_list = model._memstats_collector.model_data_list('cuda')
+    assert cuda_model_data_list == [1311744, 1836032, 1836032, 1311744, 1836032, 1836032]
+
+    cuda_non_model_data_list = model._memstats_collector.non_model_data_list('cuda')
+    assert cuda_non_model_data_list[0] > cuda_non_model_data_list[1]
+    assert cuda_non_model_data_list[-2] > cuda_non_model_data_list[-1]
+
+
+def run_dist(rank, world_size, port):
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_mem_collector_testing()
+
+
+@pytest.mark.dist
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_mem_collector(world_size=2):
+    run_func = partial(run_dist, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+if __name__ == '__main__':
+    test_mem_collector()
diff --git a/tests/test_zero/test_stateful_tensor_mgr.py b/tests/test_zero/test_stateful_tensor_mgr.py
index 5b9e35a26..ebec0bcfd 100644
--- a/tests/test_zero/test_stateful_tensor_mgr.py
+++ b/tests/test_zero/test_stateful_tensor_mgr.py
@@ -48,30 +48,39 @@ def run_stm():
     # warmup
     # use naive eviction strategy
     apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
     apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
     apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
     apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
     mem_collector.finish_collection()
     stateful_tensor_mgr.reset()
 
     # warmup done
     # use OPT-like eviction strategy
     apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
    apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.sample_overall_data()
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
-    mem_collector.sample_memstats()
+    mem_collector.sample_model_data()
+    mem_collector.finish_collection()
 
 
 def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],