mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-04 10:34:41 +00:00
[Gemini] update API of the chunkmemstatscollector. (#2129)
This commit is contained in:
@@ -55,7 +55,7 @@ class GeminiManager:
|
||||
|
||||
get the memory statistics during training.
|
||||
The stats could be collected by a runtime memory tracer, or collected by the GeminiManager.
|
||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||
Note, for the latter, you can not access the memstats before warmup iteration finishes.
|
||||
"""
|
||||
if self._premade_memstats_:
|
||||
return self._memstats
|
||||
|
@@ -11,18 +11,25 @@ from .memstats_collector import MemStatsCollector
|
||||
class ChunkMemStatsCollector(MemStatsCollector):
|
||||
|
||||
def __init__(self, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
|
||||
"""
|
||||
|
||||
Memory Statistic Collector for Chunks.
|
||||
|
||||
Args:
|
||||
chunk_manager (ChunkManager): the chunk manager.
|
||||
memstats (Optional[MemStats], optional): memory statistics collected by RMT. Defaults to None.
|
||||
"""
|
||||
super().__init__(memstats)
|
||||
self._chunk_manager = chunk_manager
|
||||
|
||||
# override
|
||||
def record_model_data_volume(self) -> None:
|
||||
"""Sampling model data statistics.
|
||||
"""
|
||||
record model data volumn on cuda and cpu.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
cuda_mem = self._chunk_manager.total_mem['cuda']
|
||||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
self._memstats.record_max_cuda_model_data(cuda_mem)
|
||||
|
||||
@property
|
||||
def cuda_margin_mem(self) -> float:
|
||||
|
@@ -22,6 +22,7 @@ class MemStats(object):
|
||||
self._preop_step = 0
|
||||
|
||||
self._prev_overall_cuda = -1
|
||||
self._max_overall_cuda = 0
|
||||
self._prev_md_cuda = -1
|
||||
|
||||
# old version
|
||||
@@ -46,6 +47,11 @@ class MemStats(object):
|
||||
|
||||
def record_max_cuda_overall_data(self, val):
|
||||
self._prev_overall_cuda = val
|
||||
self._max_overall_cuda = max(self._max_overall_cuda, val)
|
||||
|
||||
@property
|
||||
def max_overall_cuda(self):
|
||||
return self._max_overall_cuda
|
||||
|
||||
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
|
||||
"""
|
||||
@@ -85,67 +91,6 @@ class MemStats(object):
|
||||
else:
|
||||
return self._param_runtime_order
|
||||
|
||||
## APIs to be depracated
|
||||
def append_overall_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._overall_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._overall_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_model_data(self, device_type: str, val: float):
|
||||
if device_type == 'cuda':
|
||||
self._model_data_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
self._model_data_cpu_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def last_model_data(self, device_type: str):
|
||||
if len(self._model_data_cuda_list) == 0:
|
||||
return None
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list[-1]
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list[-1]
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def append_non_model_data(self, device_type: str, val=None):
|
||||
if device_type == 'cuda':
|
||||
if val is None:
|
||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
||||
return
|
||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||
else:
|
||||
self._non_model_data_cuda_list.append(val)
|
||||
elif device_type == 'cpu':
|
||||
if val is None:
|
||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
||||
return
|
||||
self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
|
||||
else:
|
||||
self._non_model_data_cuda_list.append(val)
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def overall_mem_stats(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._overall_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._overall_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._model_data_cuda_list
|
||||
elif device_type == 'cpu':
|
||||
return self._model_data_cpu_list
|
||||
else:
|
||||
raise TypeError
|
||||
|
||||
def non_model_data_list(self, device_type: str) -> List[int]:
|
||||
if device_type == 'cuda':
|
||||
return self._non_model_data_cuda_list
|
||||
|
@@ -59,6 +59,7 @@ class MemStatsCollector:
|
||||
return [t - self._sampling_time[0] for t in self._sampling_time]
|
||||
|
||||
def start_collection(self):
|
||||
print('start collection')
|
||||
self._start_flag = True
|
||||
self._mem_monitor.start()
|
||||
|
||||
@@ -68,31 +69,24 @@ class MemStatsCollector:
|
||||
self._step_total = len(self._memstats.non_model_data_list('cuda'))
|
||||
self._start_flag = False
|
||||
self._mem_monitor.finish()
|
||||
print(f'finish_collection {self._step_total}')
|
||||
|
||||
# deprecated
|
||||
def record_model_data_volume(self) -> None:
|
||||
"""Sampling model data statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
|
||||
cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
|
||||
self._memstats.append_model_data('cuda', cuda_mem)
|
||||
self._memstats.append_model_data('cpu', cpu_mem)
|
||||
raise NotImplementedError("MemStatsCollector has not implemented record_model_data_volume")
|
||||
|
||||
def sample_overall_data(self) -> None:
|
||||
"""Sampling non model data statistics.
|
||||
"""
|
||||
Sampling overall and non model data cuda memory statistics.
|
||||
"""
|
||||
if self._start_flag and not self.use_outside_memstats:
|
||||
# overall data recording is after model data recording
|
||||
if len(self._memstats._model_data_cuda_list) == 0:
|
||||
return
|
||||
cuda_overall = self._mem_monitor.finish()
|
||||
self._memstats.record_max_cuda_overall_data(cuda_overall)
|
||||
self._memstats.calc_max_cuda_non_model_data()
|
||||
|
||||
self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
|
||||
self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))
|
||||
|
||||
assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)
|
||||
|
||||
self._memstats.append_non_model_data('cuda')
|
||||
self._memstats.append_non_model_data('cpu')
|
||||
self._mem_monitor.start()
|
||||
|
||||
if self._start_flag:
|
||||
|
@@ -206,7 +206,6 @@ class ShardedModelV2(nn.Module):
|
||||
f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device()) / 1e9} GB\n')
|
||||
f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device()) / 1e9} GB\n')
|
||||
f.write('CUDA model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.model_data_list('cuda')))
|
||||
f.write('\n')
|
||||
f.write('CUDA non model data (GB)\n')
|
||||
f.write(str(self._memstats_collector._memstats.non_model_data_list('cuda')))
|
||||
@@ -256,8 +255,8 @@ class ShardedModelV2(nn.Module):
|
||||
# the way to calculate margin space is based on the assumption that
|
||||
# model data is fixed in cuda during training.
|
||||
# cuda margin space can be used to store OS.
|
||||
self._cuda_margin_space = colo_device_memory_capacity(get_current_device()) - max(
|
||||
self._memstats_collector._memstats.overall_mem_stats('cuda'))
|
||||
self._cuda_margin_space = colo_device_memory_capacity(
|
||||
get_current_device()) - self._memstats_collector._memstats.max_overall_cuda
|
||||
|
||||
@torch.no_grad()
|
||||
def _post_backward_operations(self) -> None:
|
||||
|
@@ -32,6 +32,8 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||
self._gemini_manager.adjust_layout(chunks)
|
||||
for chunk in chunks:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
# record cuda model data of the current OP
|
||||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
def post_op(self, params):
|
||||
|
Reference in New Issue
Block a user