diff --git a/colossalai/utils/memory_tracer/__init__.py b/colossalai/utils/memory_tracer/__init__.py index f40430d38..3f4cd66b8 100644 --- a/colossalai/utils/memory_tracer/__init__.py +++ b/colossalai/utils/memory_tracer/__init__.py @@ -1,3 +1,4 @@ from .async_memtracer import AsyncMemoryMonitor +from .memstats_collector import MemStatsCollector -__all__ = ['AsyncMemoryMonitor'] +__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector'] diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py index 9f69f5dde..054d3f282 100644 --- a/colossalai/utils/memory_tracer/memstats_collector.py +++ b/colossalai/utils/memory_tracer/memstats_collector.py @@ -11,15 +11,21 @@ class SamplingCounter: def __init__(self) -> None: self._samplint_cnt = 0 + self._max_sampling_cnt = None def advance(self): self._samplint_cnt += 1 + def next(self): + assert self._max_sampling_cnt is not None + return (self._samplint_cnt + 1) % self._max_sampling_cnt + @property def sampling_cnt(self): return self._samplint_cnt def reset(self): + self._max_sampling_cnt = self._samplint_cnt self._samplint_cnt = 0 @@ -56,7 +62,7 @@ class MemStatsCollector: else: raise TypeError - def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]: + def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]: if unit == 'GB': scale = 1e9 elif unit == 'MB': @@ -75,7 +81,7 @@ class MemStatsCollector: else: raise TypeError - def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]: + def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]: """Non model data stats """ if unit == 'GB': @@ -96,6 +102,14 @@ class MemStatsCollector: else: raise TypeError + def current_non_model_data(self, device_type: str) -> int: + """get the non model data of current sampling moment + """ + return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt] + + def next_non_model_data(self, device_type: str): + return self.non_model_data_list(device_type)[self._sampling_cnter.next()] + @property def sampling_time(self): return [t - self._sampling_time[0] for t in self._sampling_time] diff --git a/colossalai/zero/shard_utils/stateful_tensor_mgr.py b/colossalai/zero/shard_utils/stateful_tensor_mgr.py new file mode 100644 index 000000000..8daefeb2f --- /dev/null +++ b/colossalai/zero/shard_utils/stateful_tensor_mgr.py @@ -0,0 +1,69 @@ +import torch +from colossalai.context.singleton_meta import SingletonMeta +from colossalai.utils.cuda import get_current_device +from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 +from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState +from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage +from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity +from typing import Set +from colossalai.utils.memory_tracer import MemStatsCollector + + +class StatefulTensorMgr(SingletonMeta): + _stateful_tensor_list: Set[ShardedParamV2] = set() + + def register_param(self, param: ShardedParamV2) -> None: + for t in param.get_payload_tensors(): + assert isinstance(t, StatefulTensor) + self._stateful_tensor_list.add(t) + + def evict_tensors(self) -> None: + pass + + def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None: + """ Adjust the layout of statefuil tensor according to the information provided + by mem_stats_collector, which should belongs to a Sharded Model. + + Args: + mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model. + It contains non-model footprint of a DNN model. + """ + # find stateful tensor in state COMPUTE + move_to_cuda_tensor_list = [] + cuda_demand = 0 + used_cuda_model_data = 0 + hold_cuda_tensor_list = [] + for tensor in self._stateful_tensor_list: + if tensor.state == TensorState.FREE: + continue + + if tensor.device.type == 'cuda': + used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0] + if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]: + hold_cuda_tensor_list.append(tensor) + else: + if tensor.state == TensorState.COMPUTE: + move_to_cuda_tensor_list.append(tensor) + cuda_demand += colo_tensor_mem_usage(tensor.payload)[0] + + # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment. + max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'), + mem_stats_collector.next_non_model_data('cuda')) + cuda_capacity = colo_cuda_memory_capacity() + cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period + if cuda_model_data_period < used_cuda_model_data + cuda_demand: + # move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor + # Here use a naive eviction strategy. + acc_size = 0 + for t in hold_cuda_tensor_list: + if acc_size > cuda_demand: + break + colo_model_data_tensor_move_inline(t, torch.device('cpu')) + t_size = colo_tensor_mem_usage(t) + acc_size += t_size + if acc_size < cuda_demand: + raise RuntimeError("Adjust layout failed! No enough CUDA memory!") + + # move COMPUTE tensors to CUDA + for t in move_to_cuda_tensor_list: + colo_model_data_tensor_move_inline(t, get_current_device()) diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py index ec934213b..277eab380 100644 --- a/colossalai/zero/sharded_param/sharded_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -3,6 +3,7 @@ from colossalai.zero.sharded_param import ShardedTensor from typing import Optional, Tuple from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage from .tensorful_state import StatefulTensor, TensorState +from typing import List class ShardedParamV2(object): @@ -22,6 +23,11 @@ class ShardedParamV2(object): if rm_torch_payload: self.remove_torch_payload() + def get_payload_tensors(self) -> List[StatefulTensor]: + """returns stateful tensors kept by this class. + """ + return [self._sharded_data_tensor, self.saved_grad] + def remove_torch_payload(self): self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)