[Gemini] update API of the chunkmemstatscollector. (#2129)

This commit is contained in:
Jiarui Fang
2022-12-14 00:47:06 +08:00
committed by GitHub
parent 2938edf446
commit c89c66a858
8 changed files with 32 additions and 163 deletions

View File

@@ -55,7 +55,7 @@ class GeminiManager:
get the memory statistics during training.
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
Note, for the latter you can not access the memstats before warmup iteration finishes.
Note, for the latter, you can not access the memstats before warmup iteration finishes.
"""
if self._premade_memstats_:
return self._memstats

View File

@@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector
class ChunkMemStatsCollector(MemStatsCollector):
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
"""
Memory Statistic Collector for Chunks.
Args:
chunk_manager (ChunkManager): the chunk manager.
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
"""
super().__init__(memstats)
self._chunk_manager = chunk_manager
# override
def record_model_data_volume(self) -> None:
"""Sampling model data statistics.
"""
record model data volume on cuda and cpu.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = self._chunk_manager.total_mem['cuda']
cpu_mem = self._chunk_manager.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
self._memstats.record_max_cuda_model_data(cuda_mem)
@property
def cuda_margin_mem(self) -> float:

View File

@@ -22,6 +22,7 @@ class MemStats(object):
self._preop_step = 0
self._prev_overall_cuda = -1
self._max_overall_cuda = 0
self._prev_md_cuda = -1
# old version
@@ -46,6 +47,11 @@ class MemStats(object):
def record_max_cuda_overall_data(self, val):
self._prev_overall_cuda = val
self._max_overall_cuda = max(self._max_overall_cuda, val)
@property
def max_overall_cuda(self):
return self._max_overall_cuda
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
"""
@@ -85,67 +91,6 @@ class MemStats(object):
else:
return self._param_runtime_order
## APIs to be deprecated
def append_overall_data(self, device_type: str, val: float):
    """Append one overall-memory sample to the list kept for a device.

    Args:
        device_type (str): 'cuda' or 'cpu'; selects which list receives the sample.
        val (float): the sample to append.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        self._overall_cuda_list.append(val)
    elif device_type == 'cpu':
        self._overall_cpu_list.append(val)
    else:
        # The exception type is kept for existing callers; the message names the bad value.
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")
def append_model_data(self, device_type: str, val: float):
    """Append one model-data sample to the list kept for a device.

    Args:
        device_type (str): 'cuda' or 'cpu'; selects which list receives the sample.
        val (float): the sample to append.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        self._model_data_cuda_list.append(val)
    elif device_type == 'cpu':
        self._model_data_cpu_list.append(val)
    else:
        # The exception type is kept for existing callers; the message names the bad value.
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")
def last_model_data(self, device_type: str):
    """Return the most recent model-data sample recorded for a device.

    Args:
        device_type (str): 'cuda' or 'cpu'.

    Returns:
        The last element of the list kept for ``device_type``, or None when
        that list is still empty.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        samples = self._model_data_cuda_list
    elif device_type == 'cpu':
        samples = self._model_data_cpu_list
    else:
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")
    # Check the list that is actually being read: the previous guard always
    # looked at the cuda list, so a 'cpu' query could hit an IndexError (or
    # return None although cpu samples existed).
    if not samples:
        return None
    return samples[-1]
def append_non_model_data(self, device_type: str, val=None):
    """Append one non-model-data sample for a device.

    When ``val`` is omitted, the sample is derived as the latest overall
    sample minus the latest model-data sample of the same device. Nothing is
    appended if either of those lists is still empty.

    Args:
        device_type (str): 'cuda' or 'cpu'.
        val: explicit sample to append; None means "derive it" as described above.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        overall = self._overall_cuda_list
        model_data = self._model_data_cuda_list
        non_model_data = self._non_model_data_cuda_list
    elif device_type == 'cpu':
        # The cpu branch previously tested the cuda lists for emptiness and
        # stored an explicit ``val`` in the cuda list; use the cpu lists throughout.
        overall = self._overall_cpu_list
        model_data = self._model_data_cpu_list
        non_model_data = self._non_model_data_cpu_list
    else:
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")

    if val is None:
        if not overall or not model_data:
            # Nothing to derive from yet.
            return
        val = overall[-1] - model_data[-1]
    non_model_data.append(val)
def overall_mem_stats(self, device_type: str) -> List[int]:
    """Return the list of overall-memory samples kept for a device.

    The stored list object itself is returned, not a copy.

    Args:
        device_type (str): 'cuda' or 'cpu'.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        return self._overall_cuda_list
    elif device_type == 'cpu':
        return self._overall_cpu_list
    else:
        # The exception type is kept for existing callers; the message names the bad value.
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")
def model_data_list(self, device_type: str) -> List[int]:
    """Return the list of model-data samples kept for a device.

    The stored list object itself is returned, not a copy.

    Args:
        device_type (str): 'cuda' or 'cpu'.

    Raises:
        TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
    """
    if device_type == 'cuda':
        return self._model_data_cuda_list
    elif device_type == 'cpu':
        return self._model_data_cpu_list
    else:
        # The exception type is kept for existing callers; the message names the bad value.
        raise TypeError(f"unsupported device_type {device_type!r}, expected 'cuda' or 'cpu'")
def non_model_data_list(self, device_type: str) -> List[int]:
if device_type == 'cuda':
return self._non_model_data_cuda_list

View File

@@ -59,6 +59,7 @@ class MemStatsCollector:
return [t - self._sampling_time[0] for t in self._sampling_time]
def start_collection(self):
print('start collection')
self._start_flag = True
self._mem_monitor.start()
@@ -68,31 +69,24 @@ class MemStatsCollector:
self._step_total = len(self._memstats.non_model_data_list('cuda'))
self._start_flag = False
self._mem_monitor.finish()
print(f'finish_collection {self._step_total}')
# deprecated
def record_model_data_volume(self) -> None:
"""Sampling model data statistics.
"""
if self._start_flag and not self.use_outside_memstats:
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
self._memstats.append_model_data('cuda', cuda_mem)
self._memstats.append_model_data('cpu', cpu_mem)
raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")
def sample_overall_data(self) -> None:
"""Sampling non model data statistics.
"""
Sampling overall and non model data cuda memory statistics.
"""
if self._start_flag and not self.use_outside_memstats:
# overall data recording is after model data recording
if len(self._memstats._model_data_cuda_list) == 0:
return
cuda_overall = self._mem_monitor.finish()
self._memstats.record_max_cuda_overall_data(cuda_overall)
self._memstats.calc_max_cuda_non_model_data()
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
self._memstats.append_non_model_data('cuda')
self._memstats.append_non_model_data('cpu')
self._mem_monitor.start()
if self._start_flag:

View File

@@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
f.write('CUDA model data (GB)\n')
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
f.write('\n')
f.write('CUDA non model data (GB)\n')
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
@@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
# the way to calculate margin space is based on the assumption that
# model data is fixed in cuda during training.
# cuda margin space can be used to store OS.
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
self._memstats_collector._memstats.overall_mem_stats('cuda'))
self._cuda_margin_space = colo_device_memory_capacity(
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
@torch.no_grad()
def _post_backward_operations(self) -> None:

View File

@@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:
self._chunk_manager.access_chunk(chunk)
# record cuda model data of the current OP
self._gemini_manager.record_model_data_volume()
def post_op(self, params):