Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 05:29:36 +00:00)

[Gemini] param_tracer_wrapper and test case (#2009)

Parent: 1438993113
Commit: 0160a62a3c
colossalai/gemini/memory_tracer/__init__.py

@@ -4,8 +4,9 @@ from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip
 from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
 from .static_memstats_collector import StaticMemStatsCollector # isort:skip
 from .module_tracer_wrapper import MemtracerWrapper # isort:skip
+from .param_tracer_wrapper import ParamWrapper # isort:skip
 
 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper', 'ParamWrapper'
 ]
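With this re-export in place, callers can import the wrapper from the package root rather than the full module path. A minimal sketch, assuming a ColossalAI build that includes this commit:

from colossalai.gemini.memory_tracer import ParamWrapper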
colossalai/gemini/memory_tracer/param_tracer_wrapper.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import torch.nn
+
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.gemini.ophooks import ParamMemHook
+from colossalai.nn.parallel.data_parallel import _cast_float
+
+
+class ParamWrapper():
+
+    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
+        super().__init__()
+        self.module = module
+        self.dtype = dtype
+        self.param_op_hook = ParamMemHook()
+
+        for p in module.parameters():
+            assert isinstance(p, ColoParameter)
+            p.data = p.data.to(dtype)
+
+        self._cast_buffers_to_cuda_dtype()
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def _pre_forward(self):
+        self.param_op_hook.mem_monitor.start()
+
+    def forward(self, *args, **kwargs):
+        args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
+        self.module.zero_grad(set_to_none=True)
+        self._pre_forward()
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
+            outputs = self.module(*args, **kwargs)
+        return outputs
+
+    def backward(self, loss):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+            loss.backward()
+        self._post_backward()
+
+    def _post_backward(self):
+        cuda_volume = self.param_op_hook.mem_monitor.finish()
+        last_model_data = self.param_op_hook._model_data_list[-1]
+        self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
+
+    def _cast_buffers_to_cuda_dtype(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.data.to(self.dtype)
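The wrapper's accounting runs in two phases: _pre_forward starts the hook's CUDA memory monitor before each forward pass, and _post_backward records non-model data as the monitored CUDA volume minus the model-data volume that ParamMemHook attributed to parameters and gradients. A minimal usage sketch under those semantics; model, batch, target, and loss_fn are placeholder names, and the model must hold ColoParameters (e.g. one built under ColoInitContext):

wrapper = ParamWrapper(model)                # casts parameters and buffers to fp16
logits = wrapper(batch.cuda())               # forward pass runs under ParamMemHook
wrapper.backward(loss_fn(logits, target))    # backward pass, then per-iteration bookkeeping
stats = wrapper.param_op_hook._non_model_data_list   # one byte count per iteration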
tests/test_gemini/test_mem_tracer_paramOP.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+import numpy as np
+import torch
+
+from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamWrapper
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.float()
+    model.backward(loss)
+
+
+def run_param_wrapper_testing():
+    test_models = ['repeated_computed_layers', 'simple_net', 'no_leaf_module', 'bert']
+
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+        with ColoInitContext(device=torch.device('cpu')):
+            model = model_builder(checkpoint=False)
+
+        model = ParamWrapper(model)
+
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 1:
+                break
+            data = data.cuda()
+            label = label.cuda()
+
+            run_fwd_bwd(model, data, label, criterion, False)
+
+        cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
+        print("cuda_non_model_data_list", len(cuda_non_model_data_list))
+        # print(model.param_op_hook._non_model_data_list)
+
+        del model
+
+
+if __name__ == '__main__':
+    run_param_wrapper_testing()
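Each value the test converts to MiB is cuda_volume - last_model_data from _post_backward: memory the monitor observed (presumably the peak over the monitored span) that the hook could not attribute to parameters or gradients, i.e. activations, temporaries, and workspace. A worked example with invented numbers, not measurements from this test:

peak_cuda = 6 * 1024**3              # hypothetical: 6 GiB reported by the memory monitor
model_data = 4 * 1024**3             # hypothetical: 4 GiB attributed to params/grads
non_model = peak_cuda - model_data   # 2 GiB would be recorded as non-model data
assert non_model == 2 * 1024**3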