[zero] stateful tensor manager (#687)

* [WIP] stateful tensor manager

* add eviction strategy

* polish code

* polish code

* polish comment

* add unit test

* fix sampler bug

* polish code

* fix max sampling cnt resetting bug

* fix sampler bug

* polish code

* fix bug

* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>
This commit is contained in:
ver217
2022-04-08 17:51:34 +08:00
committed by GitHub
parent 70e8dd418b
commit 3c9cd5bb5e
8 changed files with 271 additions and 73 deletions

View File

@@ -20,10 +20,12 @@ class SamplingCounter:
assert self._max_sampling_cnt is not None
return (self._samplint_cnt + 1) % self._max_sampling_cnt
@property
def sampling_cnt(self):
def current(self):
return self._samplint_cnt
def max(self):
return self._max_sampling_cnt
def reset(self):
self._max_sampling_cnt = self._samplint_cnt
self._samplint_cnt = 0
@@ -37,7 +39,7 @@ class MemStatsCollector:
The first iteration of DNN training.
Phase 2. Runtime Phase: use the read-only collected stats
The rest iterations of DNN training.
    It has a sampling counter which is reset after each DNN training iteration.
"""
@@ -50,6 +52,8 @@ class MemStatsCollector:
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._sampling_time = []
self._start_flag = False
@@ -96,18 +100,20 @@ class MemStatsCollector:
raise TypeError
if device_type == 'cuda':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
return [elem / scale for elem in self._non_model_data_cuda_list]
elif device_type == 'cpu':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
return [elem / scale for elem in self._non_model_data_cpu_list]
else:
raise TypeError
def current_non_model_data(self, device_type: str) -> int:
"""get the non model data of current sampling moment
"""get the non model data of the current sampling moment
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
return self.non_model_data_list(device_type)[self._sampling_cnter.current()]
def next_non_model_data(self, device_type: str):
"""get the non model data of the next sampling moment
"""
return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
@property
@@ -128,18 +134,20 @@ class MemStatsCollector:
Advance the sampling cnter.
"""
if self._start_flag:
sampling_cnt = self._sampling_cnter.sampling_cnt
sampling_cnt = self._sampling_cnter.current()
assert sampling_cnt == len(self._overall_cuda_list)
self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
self._overall_cuda_list.append(self._mem_monitor.finish())
self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
# FIXME() cpu sys used should also return from self._mem_monitor()
# FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
self._sampling_time.append(time.time())
self._mem_monitor.start()
# TODO(ver217): refactor sampler
# print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
self._sampling_cnter.advance()
def reset_sampling_cnter(self) -> None:
@@ -155,4 +163,4 @@ class MemStatsCollector:
self._start_flag = False
self._sampling_cnter.reset()
self._mem_monitor.finish()
self._mem_monitor.finish()