mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-23 06:00:44 +00:00
[Gemini] fix grad unreleased issue and param recovery issue (#2052)
This commit is contained in:
parent
edf4cd46c5
commit
38ea4ba1bd
@ -106,4 +106,15 @@ class ModelDataTracer(metaclass=SingletonMeta):
|
|||||||
return self._get_mem_usage()
|
return self._get_mem_usage()
|
||||||
|
|
||||||
|
|
||||||
|
class CudaMemInfo(metaclass=SingletonMeta):
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.model_data_list = []
|
||||||
|
self.non_model_data_list = []
|
||||||
|
self.unreleased_grad_flag = {}
|
||||||
|
self.unreleased_grad_volume = 0
|
||||||
|
|
||||||
|
|
||||||
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()
|
||||||
|
|
||||||
|
GLOBAL_CUDA_MEM_INFO = CudaMemInfo()
|
@ -1,11 +1,9 @@
|
|||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
|
from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook, GradHook
|
||||||
from colossalai.gemini.tensor_utils import free_storage
|
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['ParamTracerWrapper']
|
__all__ = ['ParamTracerWrapper']
|
||||||
|
|
||||||
@ -15,22 +13,33 @@ class ParamTracerWrapper():
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.module = module
|
self.module = module
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
self.param_op_hook = ParamTracerHook(dtype)
|
self.param_op_hook = ParamTracerHook()
|
||||||
|
self.grad_hook = GradHook(module)
|
||||||
|
self.cpu_param_data_dict = {}
|
||||||
|
|
||||||
for p in module.parameters():
|
for p in module.parameters():
|
||||||
p.data = p.data.to(dtype)
|
p.data = p.data.to(dtype)
|
||||||
if p.requires_grad:
|
|
||||||
p.register_hook(partial(self.grad_handle))
|
|
||||||
|
|
||||||
self._cast_buffers_to_cuda_dtype()
|
self._cast_buffers_to_cuda_dtype()
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.forward(*args, **kwargs)
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
def grad_handle(self, grad):
|
def _save_param_data_on_cpu(self):
|
||||||
free_storage(grad)
|
for p in self.module.parameters():
|
||||||
|
self.cpu_param_data_dict[p] = torch.empty(p.data.shape, dtype=self.dtype, device="cpu")
|
||||||
|
self.cpu_param_data_dict[p].copy_(p.data)
|
||||||
|
|
||||||
|
def _restore_param_data(self):
|
||||||
|
for p in self.module.parameters():
|
||||||
|
p.data = torch.empty(p.data.shape, dtype=self.dtype, device="cpu", requires_grad=p.data.requires_grad)
|
||||||
|
p.data.copy_(self.cpu_param_data_dict[p])
|
||||||
|
self.cpu_param_data_dict.clear()
|
||||||
|
|
||||||
def _pre_forward(self):
|
def _pre_forward(self):
|
||||||
|
self._clear_cuda_mem_info()
|
||||||
|
self._save_param_data_on_cpu()
|
||||||
|
self.grad_hook.register_grad_hook()
|
||||||
self.param_op_hook.mem_monitor.start()
|
self.param_op_hook.mem_monitor.start()
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
@ -48,8 +57,16 @@ class ParamTracerWrapper():
|
|||||||
|
|
||||||
def _post_backward(self):
|
def _post_backward(self):
|
||||||
cuda_volume = self.param_op_hook.mem_monitor.finish()
|
cuda_volume = self.param_op_hook.mem_monitor.finish()
|
||||||
last_model_data = self.param_op_hook._model_data_list[-1]
|
last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
||||||
self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
|
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
||||||
|
self.grad_hook.remove_grad_hook()
|
||||||
|
self._restore_param_data()
|
||||||
|
|
||||||
|
def _clear_cuda_mem_info(self):
|
||||||
|
GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
||||||
|
GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
|
||||||
|
|
||||||
def _cast_buffers_to_cuda_dtype(self):
|
def _cast_buffers_to_cuda_dtype(self):
|
||||||
for buffer in self.module.buffers():
|
for buffer in self.module.buffers():
|
||||||
|
@ -8,6 +8,7 @@ import torch
|
|||||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||||
from colossalai.gemini.tensor_utils import free_storage, alloc_storage
|
from colossalai.gemini.tensor_utils import free_storage, alloc_storage
|
||||||
|
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||||
|
|
||||||
|
|
||||||
class TrainingPhase(Enum):
|
class TrainingPhase(Enum):
|
||||||
@ -15,42 +16,69 @@ class TrainingPhase(Enum):
|
|||||||
BACKWARD = 1
|
BACKWARD = 1
|
||||||
|
|
||||||
|
|
||||||
|
class GradHook():
|
||||||
|
def __init__(self, module: torch.nn.Module):
|
||||||
|
self.module = module
|
||||||
|
self.grad_hook_list = []
|
||||||
|
|
||||||
|
def grad_handle(self, p, grad):
|
||||||
|
assert GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]
|
||||||
|
free_storage(grad)
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume -= grad.numel() * grad.element_size()
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
|
||||||
|
|
||||||
|
def register_grad_hook(self):
|
||||||
|
for p in self.module.parameters():
|
||||||
|
if p.requires_grad:
|
||||||
|
self.grad_hook_list.append(p.register_hook(partial(self.grad_handle, p)))
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = False
|
||||||
|
|
||||||
|
def remove_grad_hook(self):
|
||||||
|
for hook in self.grad_hook_list:
|
||||||
|
hook.remove()
|
||||||
|
|
||||||
|
|
||||||
class ParamTracerHook(ParamOpHook):
|
class ParamTracerHook(ParamOpHook):
|
||||||
|
|
||||||
def __init__(self, dtype: torch.dtype = torch.half) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._training_phase = TrainingPhase.FORWARD
|
self._training_phase = TrainingPhase.FORWARD
|
||||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||||
self._non_model_data_list = []
|
|
||||||
self._model_data_list = []
|
|
||||||
self.dtype = dtype
|
|
||||||
|
|
||||||
def _free_cuda_params(self, params):
|
def _free_cuda_params(self, params):
|
||||||
for p in params:
|
for p in params:
|
||||||
|
if p.data.device.type == "cpu":
|
||||||
|
raise NotImplementedError("Only free cuda memory")
|
||||||
free_storage(p.data)
|
free_storage(p.data)
|
||||||
|
|
||||||
def _allocate_params_on_cuda(self, params):
|
def _allocate_params_on_cuda(self, params):
|
||||||
for p in params:
|
for p in params:
|
||||||
cur_dev = p.data.device.type
|
cur_dev = p.data.device.type
|
||||||
if cur_dev == "cpu":
|
if cur_dev == "cpu":
|
||||||
# p.data = p.data.to("cuda")
|
if p.grad is not None and p.grad.device.type == "cpu":
|
||||||
p.data = torch.randn(p.data.shape, device="cuda", dtype=self.dtype)
|
raise NotImplementedError("Only run in forward propagation")
|
||||||
|
p.data = torch.empty(p.data.shape, device="cuda", dtype=p.data.dtype,
|
||||||
|
requires_grad=p.data.requires_grad)
|
||||||
elif cur_dev == "cuda":
|
elif cur_dev == "cuda":
|
||||||
alloc_storage(p.data)
|
alloc_storage(p.data)
|
||||||
|
|
||||||
def sample_model_data(self, params):
|
def sample_model_data(self, params):
|
||||||
data_volume = 0
|
data_volume = GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume
|
||||||
for p in params:
|
for p in params:
|
||||||
data_volume += p.data.numel() * p.data.element_size()
|
cur_model_data_volume = p.data.numel() * p.data.element_size()
|
||||||
if self._training_phase == TrainingPhase.BACKWARD:
|
data_volume += cur_model_data_volume
|
||||||
# add param.grad, actually param.grad is None in this time
|
if self._training_phase == TrainingPhase.BACKWARD and p.requires_grad:
|
||||||
data_volume *= 2
|
# add param.grad, actually param.grad is None in this time
|
||||||
self._model_data_list.append(data_volume)
|
data_volume += cur_model_data_volume
|
||||||
|
if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
|
||||||
|
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
|
||||||
|
GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
|
||||||
|
|
||||||
def pre_op(self, params):
|
def pre_op(self, params):
|
||||||
cuda_volume = self.mem_monitor.finish()
|
cuda_volume = self.mem_monitor.finish()
|
||||||
if len(self._model_data_list):
|
if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
|
||||||
self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
|
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
|
||||||
self._allocate_params_on_cuda(params)
|
self._allocate_params_on_cuda(params)
|
||||||
self.sample_model_data(params)
|
self.sample_model_data(params)
|
||||||
self.mem_monitor.start()
|
self.mem_monitor.start()
|
||||||
|
@ -2,6 +2,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
|
from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
|
||||||
|
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
|
||||||
@ -35,9 +36,9 @@ def run_param_wrapper_testing():
|
|||||||
|
|
||||||
run_fwd_bwd(model, data, label, criterion, False)
|
run_fwd_bwd(model, data, label, criterion, False)
|
||||||
|
|
||||||
cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
|
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024 ** 2
|
||||||
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
||||||
# print(model.param_op_hook._non_model_data_list)
|
# print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
|
||||||
|
|
||||||
del model
|
del model
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user