mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-15 22:19:38 +00:00
[zero] stateful tensor manager (#687)
* [WIP] stateful tensor manager
* add eviction strategy
* polish code
* polish code
* polish comment
* add unit test
* fix sampler bug
* polish code
* fix max sampling cnt resetting bug
* fix sampler bug
* polish code
* fix bug
* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
This commit is contained in:
@@ -20,10 +20,12 @@ class SamplingCounter:
|
||||
assert self._max_sampling_cnt is not None
|
||||
return (self._samplint_cnt + 1) % self._max_sampling_cnt
|
||||
|
||||
@property
|
||||
def sampling_cnt(self):
|
||||
def current(self):
|
||||
return self._samplint_cnt
|
||||
|
||||
def max(self):
|
||||
return self._max_sampling_cnt
|
||||
|
||||
def reset(self):
|
||||
self._max_sampling_cnt = self._samplint_cnt
|
||||
self._samplint_cnt = 0
|
||||
@@ -37,7 +39,7 @@ class MemStatsCollector:
|
||||
The first iteration of DNN training.
|
||||
Phase 2. Runtime Phase: use the read-only collected stats
|
||||
The rest iterations of DNN training.
|
||||
|
||||
|
||||
It has a Sampling counter which is reset after DNN training iteration.
|
||||
"""
|
||||
|
||||
@@ -50,6 +52,8 @@ class MemStatsCollector:
|
||||
self._model_data_cpu_list = []
|
||||
self._overall_cpu_list = []
|
||||
|
||||
self._non_model_data_cuda_list = []
|
||||
self._non_model_data_cpu_list = []
|
||||
self._sampling_time = []
|
||||
|
||||
self._start_flag = False
|
||||
@@ -96,18 +100,20 @@ class MemStatsCollector:
|
||||
raise TypeError
|
||||
|
||||
if device_type == 'cuda':
|
||||
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
|
||||
return [elem / scale for elem in self._non_model_data_cuda_list]
|
||||
elif device_type == 'cpu':
|
||||
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
|
||||
return [elem / scale for elem in self._non_model_data_cpu_list]
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def current_non_model_data(self, device_type: str) -> int:
|
||||
"""get the non model data of current sampling moment
|
||||
"""get the non model data of the current sampling moment
|
||||
"""
|
||||
return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
|
||||
return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
|
||||
|
||||
def next_non_model_data(self, device_type: str):
|
||||
"""get the non model data of the next sampling moment
|
||||
"""
|
||||
return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
|
||||
|
||||
@property
|
||||
@@ -128,18 +134,20 @@ class MemStatsCollector:
|
||||
Advance the sampling cnter.
|
||||
"""
|
||||
if self._start_flag:
|
||||
sampling_cnt = self._sampling_cnter.sampling_cnt
|
||||
sampling_cnt = self._sampling_cnter.current()
|
||||
assert sampling_cnt == len(self._overall_cuda_list)
|
||||
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
|
||||
self._overall_cuda_list.append(self._mem_monitor.finish())
|
||||
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
|
||||
|
||||
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
|
||||
|
||||
# FIXME() cpu sys used should also return from self._mem_monitor()
|
||||
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
|
||||
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
|
||||
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
self._sampling_time.append(time.time())
|
||||
self._mem_monitor.start()
|
||||
# TODO(ver217): refactor sampler
|
||||
# print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
|
||||
self._sampling_cnter.advance()
|
||||
|
||||
def reset_sampling_cnter(self) -> None:
|
||||
@@ -155,4 +163,4 @@ class MemStatsCollector:
|
||||
|
||||
self._start_flag = False
|
||||
self._sampling_cnter.reset()
|
||||
self._mem_monitor.finish()
|
||||
self._mem_monitor.finish()
|
||||
|
Reference in New Issue
Block a user