mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-11 12:51:55 +00:00
[Gemini] rename ParamTracerWrapper -> RuntimeMemTracer (#2073)
This commit is contained in:
parent
9f828ef36f
commit
223332ff7e
@ -1,13 +1,14 @@
|
|||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
|
||||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
|
|
||||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||||
|
from colossalai.gemini.ophooks.param_trace_hook import GradHook, ParamTracerHook
|
||||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||||
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
|
|
||||||
__all__ = ['ParamTracerWrapper']
|
__all__ = ['RuntimeMemTracer']
|
||||||
|
|
||||||
class ParamTracerWrapper():
|
|
||||||
|
class RuntimeMemTracer():
|
||||||
|
|
||||||
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
|
def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -25,12 +26,18 @@ class ParamTracerWrapper():
|
|||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
def _save_param_data_on_cpu(self):
|
def _backup_params(self):
|
||||||
|
"""
|
||||||
|
The function is called before forward. Backup model params on cpu.
|
||||||
|
"""
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
|
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
|
||||||
self.cpu_param_data_dict[p].copy_(p.data)
|
self.cpu_param_data_dict[p].copy_(p.data)
|
||||||
|
|
||||||
def _restore_param_data(self):
|
def _restore_params(self):
|
||||||
|
"""
|
||||||
|
This function is called after backward. Restore model params.
|
||||||
|
"""
|
||||||
for p in self.module.parameters():
|
for p in self.module.parameters():
|
||||||
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
|
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
|
||||||
p.data.copy_(self.cpu_param_data_dict[p])
|
p.data.copy_(self.cpu_param_data_dict[p])
|
||||||
@ -38,7 +45,7 @@ class ParamTracerWrapper():
|
|||||||
|
|
||||||
def _pre_forward(self):
|
def _pre_forward(self):
|
||||||
self._clear_cuda_mem_info()
|
self._clear_cuda_mem_info()
|
||||||
self._save_param_data_on_cpu()
|
self._backup_params()
|
||||||
self.grad_hook.register_grad_hook()
|
self.grad_hook.register_grad_hook()
|
||||||
self.param_op_hook.mem_monitor.start()
|
self.param_op_hook.mem_monitor.start()
|
||||||
|
|
||||||
@ -60,7 +67,7 @@ class ParamTracerWrapper():
|
|||||||
last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
||||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
||||||
self.grad_hook.remove_grad_hook()
|
self.grad_hook.remove_grad_hook()
|
||||||
self._restore_param_data()
|
self._restore_params()
|
||||||
|
|
||||||
def _clear_cuda_mem_info(self):
|
def _clear_cuda_mem_info(self):
|
||||||
GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
@ -1,11 +1,15 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
|
|
||||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||||
|
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from tests.components_to_test import run_fwd_bwd
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
|
|
||||||
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half):
|
def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half):
|
||||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||||
if criterion:
|
if criterion:
|
||||||
@ -16,9 +20,9 @@ def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torc
|
|||||||
loss = loss.to(dtype)
|
loss = loss.to(dtype)
|
||||||
model.backward(loss)
|
model.backward(loss)
|
||||||
|
|
||||||
|
|
||||||
def run_param_wrapper_testing():
|
def run_param_wrapper_testing():
|
||||||
test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
|
test_models = ['simple_net', 'repeated_computed_layers', 'nested_model']
|
||||||
|
|
||||||
for model_name in test_models:
|
for model_name in test_models:
|
||||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||||
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
model_builder, train_dataloader, _, _, criterion = get_components_func()
|
||||||
@ -26,7 +30,8 @@ def run_param_wrapper_testing():
|
|||||||
with ColoInitContext(device=torch.device('cpu')):
|
with ColoInitContext(device=torch.device('cpu')):
|
||||||
model = model_builder(checkpoint=False)
|
model = model_builder(checkpoint=False)
|
||||||
|
|
||||||
model = ParamTracerWrapper(model)
|
model_bk = deepcopy(model)
|
||||||
|
runtime_mem_tracer = RuntimeMemTracer(model)
|
||||||
|
|
||||||
for i, (data, label) in enumerate(train_dataloader):
|
for i, (data, label) in enumerate(train_dataloader):
|
||||||
if i > 1:
|
if i > 1:
|
||||||
@ -34,7 +39,10 @@ def run_param_wrapper_testing():
|
|||||||
data = data.cuda()
|
data = data.cuda()
|
||||||
label = label.cuda()
|
label = label.cuda()
|
||||||
|
|
||||||
run_fwd_bwd(model, data, label, criterion, False)
|
run_fwd_bwd(runtime_mem_tracer, data, label, criterion, False)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model_bk.parameters(), model.parameters()):
|
||||||
|
torch.allclose(p1.to(torch.half), p2)
|
||||||
|
|
||||||
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
|
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
|
||||||
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
||||||
@ -43,6 +51,5 @@ def run_param_wrapper_testing():
|
|||||||
del model
|
del model
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
run_param_wrapper_testing()
|
run_param_wrapper_testing()
|
Loading…
Reference in New Issue
Block a user