mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2026-01-29 21:49:54 +00:00
[Gemini] update non model data calculation method (#2126)
This commit is contained in:
@@ -11,13 +11,19 @@ class MemStats(object):
|
||||
"""
|
||||
Store the non model data statistics used for Gemini and ZeroOptimizer.
|
||||
"""
|
||||
# p -> list of non_model data volume visited in order.
|
||||
|
||||
# (preop_moment, List[param])
|
||||
# (preop_step, List[param])
|
||||
self._step_param_dict = dict()
|
||||
# (param, List[preop_step])
|
||||
self._param_step_dict = dict()
|
||||
# (preop_step, non_model_data)
|
||||
self._step_nmd_dict = dict()
|
||||
self._param_runtime_order = OrderedParamGenerator()
|
||||
|
||||
# (param, List[preop_moment])
|
||||
self._preop_step = 0
|
||||
|
||||
self._prev_overall_cuda = -1
|
||||
self._prev_md_cuda = -1
|
||||
# old version
|
||||
self.param_non_model_data_map: Dict(Any, List[int]) = {}
|
||||
|
||||
self._model_data_cuda_list = []
|
||||
@@ -29,9 +35,15 @@ class MemStats(object):
|
||||
self._non_model_data_cuda_list = []
|
||||
self._non_model_data_cpu_list = []
|
||||
|
||||
self._param_runtime_order = OrderedParamGenerator()
|
||||
def record_max_cuda_non_model_data(self):
|
||||
if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
|
||||
self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
|
||||
|
||||
self._preop_step = 0
|
||||
def record_max_cuda_model_data(self, val):
|
||||
self._prev_md_cuda = val
|
||||
|
||||
def record_max_cuda_overall_data(self, val):
|
||||
self._prev_overall_cuda = val
|
||||
|
||||
def param_order(self):
|
||||
if self._param_runtime_order.is_empty():
|
||||
@@ -168,4 +180,8 @@ class MemStats(object):
|
||||
self._param_runtime_order.clear()
|
||||
self._step_param_dict.clear()
|
||||
self._param_step_dict.clear()
|
||||
self._step_nmd_dict.clear()
|
||||
self._preop_step = 0
|
||||
|
||||
self._prev_overall_cuda = -1
|
||||
self._prev_md_cuda = -1
|
||||
|
||||
@@ -64,7 +64,16 @@ class ParamMemTracerHook(ColoParamOpHook):
|
||||
raise NotImplementedError("Only free cuda memory")
|
||||
free_storage(p.data)
|
||||
|
||||
def _allocate_params_on_cuda(self, params):
|
||||
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
|
||||
"""
|
||||
move params to cuda
|
||||
|
||||
Args:
|
||||
params (List[torch.nn.Parameter]): target params
|
||||
|
||||
Raises:
|
||||
NotImplementedError: raise error when param has cpu grad
|
||||
"""
|
||||
for p in params:
|
||||
cur_dev = p.data.device.type
|
||||
if cur_dev == "cpu":
|
||||
@@ -78,6 +87,9 @@ class ParamMemTracerHook(ColoParamOpHook):
|
||||
alloc_storage(p.data)
|
||||
|
||||
def sample_model_data(self, params):
|
||||
"""
|
||||
get cuda model data used by params
|
||||
"""
|
||||
data_volume = self._grad_stats.unreleased_grad_volume
|
||||
for p in params:
|
||||
cur_model_data_volume = p.data.numel() * p.data.element_size()
|
||||
@@ -89,14 +101,21 @@ class ParamMemTracerHook(ColoParamOpHook):
|
||||
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
|
||||
self._grad_stats.unreleased_grad_flag[p] = True
|
||||
self._memstats.append_model_data('cuda', data_volume)
|
||||
# record max cuda model data used for this Op
|
||||
self._memstats.record_max_cuda_model_data(data_volume)
|
||||
|
||||
def pre_op(self, params):
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
last_model_data_val = self._memstats.last_model_data('cuda')
|
||||
if last_model_data_val is not None:
|
||||
self._memstats.append_non_model_data('cuda', cuda_volume - last_model_data_val)
|
||||
# get overall cuda data.
|
||||
max_cuda_vol_of_period = self.mem_monitor.finish()
|
||||
# record max cuda overall data for prev Op.
|
||||
self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period)
|
||||
self._memstats.record_max_cuda_non_model_data()
|
||||
max_cuda_model_data_val = self._memstats.last_model_data('cuda')
|
||||
if max_cuda_model_data_val is not None:
|
||||
self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val)
|
||||
self._allocate_params_on_cuda(params)
|
||||
self.sample_model_data(params)
|
||||
|
||||
self.mem_monitor.start()
|
||||
self._memstats.increase_preop_step(params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user