diff --git a/colossalai/gemini/memory_tracer/param_tracer_wrapper.py b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
new file mode 100644
index 000000000..b6b26fe9a
--- /dev/null
+++ b/colossalai/gemini/memory_tracer/param_tracer_wrapper.py
@@ -0,0 +1,52 @@
+import torch.nn
+
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.gemini.ophooks.param_trace_hook import ParamTracerHook
+from colossalai.nn.parallel.data_parallel import _cast_float
+
+__all__ = ['ParamTracerWrapper']
+
+class ParamTracerWrapper:
+
+    def __init__(self, module: torch.nn.Module, dtype: torch.dtype = torch.half):
+        super().__init__()
+        self.module = module
+        self.dtype = dtype
+        self.param_op_hook = ParamTracerHook()
+
+        for p in module.parameters():
+            assert isinstance(p, ColoParameter)
+            p.data = p.data.to(dtype)
+
+        self._cast_buffers_to_cuda_dtype()
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def _pre_forward(self):
+        self.param_op_hook.mem_monitor.start()
+
+    def forward(self, *args, **kwargs):
+        args, kwargs = _cast_float(args, self.dtype), _cast_float(kwargs, self.dtype)
+        self.module.zero_grad(set_to_none=True)
+        self._pre_forward()
+        with ParamOpHookManager.use_hooks(self.param_op_hook):
+            outputs = self.module(*args, **kwargs)
+        return outputs
+
+    def backward(self, loss):
+        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+            loss.backward()
+        self._post_backward()
+
+    def _post_backward(self):
+        cuda_volume = self.param_op_hook.mem_monitor.finish()
+        last_model_data = self.param_op_hook._model_data_list[-1]
+        self.param_op_hook._non_model_data_list.append(cuda_volume - last_model_data)
+
+    def _cast_buffers_to_cuda_dtype(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.data.to(self.dtype)
\ No newline at end of file
diff --git a/colossalai/gemini/ophooks/param_trace_hook.py b/colossalai/gemini/ophooks/param_trace_hook.py
new file mode 100644
index 000000000..970dcb5c4
--- /dev/null
+++ b/colossalai/gemini/ophooks/param_trace_hook.py
@@ -0,0 +1,81 @@
+from contextlib import contextmanager
+from enum import Enum
+from functools import partialmethod
+from typing import List
+
+import torch
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.tensor.param_op_hook import ParamOpHook
+
+
+class TrainingPhase(Enum):
+    FORWARD = 0
+    BACKWARD = 1
+
+
+class ParamTracerHook(ParamOpHook):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._training_phase = TrainingPhase.FORWARD
+        self.mem_monitor = SyncCudaMemoryMonitor()
+        self._non_model_data_list = []
+        self._model_data_list = []
+
+    def _move_params_to_dev(self, params, dev: str) -> int:
+        assert isinstance(dev, str), "device should be a str, not a torch.device"
+        comm_volume = 0
+        for p in params:
+            if p.data.device.type != dev:
+                p.data = p.data.to(dev)
+                comm_volume += p.data.numel() * p.data.element_size()
+            if p.grad is not None:
+                if p.grad.device.type != dev:
+                    p.grad = p.grad.to(dev)
+                    comm_volume += p.grad.numel() * p.grad.element_size()
+        return comm_volume
+
+    def sample_model_data(self, params):
+        data_volume = 0
+        for p in params:
+            data_volume += p.data.numel() * p.data.element_size()
+        if self._training_phase == TrainingPhase.BACKWARD:
+            # account for param.grad; it is still None at this point but will match param.data in size
+            data_volume *= 2
+        self._model_data_list.append(data_volume)
+
+    def pre_op(self, params):
+        cuda_volume = self.mem_monitor.finish()
+        if len(self._model_data_list):
+            self._non_model_data_list.append(cuda_volume - self._model_data_list[-1])
+        self._move_params_to_dev(params, 'cuda')
+        self.sample_model_data(params)
+        self.mem_monitor.start()
+
+    def post_op(self, params):
+        self._move_params_to_dev(params, 'cpu')
+
+    def pre_forward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_forward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    def pre_backward(self, params: List[torch.Tensor]) -> None:
+        self.pre_op(params)
+
+    def post_backward(self, params: List[torch.Tensor]) -> None:
+        self.post_op(params)
+
+    @contextmanager
+    def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
+        old_training_phase = self._training_phase
+        try:
+            self._training_phase = training_phase
+            yield
+        finally:
+            self._training_phase = old_training_phase
+
+    switch_to_backward = switch_training_phase
+    switch_to_forward = partialmethod(switch_to_backward, training_phase=TrainingPhase.FORWARD)
\ No newline at end of file
diff --git a/tests/test_gemini/test_param_tracer.py b/tests/test_gemini/test_param_tracer.py
new file mode 100644
index 000000000..79f311cb5
--- /dev/null
+++ b/tests/test_gemini/test_param_tracer.py
@@ -0,0 +1,47 @@
+import numpy as np
+import torch
+
+from colossalai.gemini.memory_tracer.param_tracer_wrapper import ParamTracerWrapper
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, dtype=torch.half):
+    with torch.cuda.amp.autocast(enabled=enable_autocast):
+        if criterion:
+            y = model(data)
+            loss = criterion(y, label)
+        else:
+            loss = model(data, label)
+        loss = loss.to(dtype)
+    model.backward(loss)
+
+def run_param_wrapper_testing():
+    test_models = ['simple_net']
+
+    for model_name in test_models:
+        get_components_func = non_distributed_component_funcs.get_callable(model_name)
+        model_builder, train_dataloader, _, _, criterion = get_components_func()
+
+        with ColoInitContext(device=torch.device('cpu')):
+            model = model_builder(checkpoint=False)
+
+        model = ParamTracerWrapper(model)
+
+        for i, (data, label) in enumerate(train_dataloader):
+            if i > 1:
+                break
+            data = data.cuda()
+            label = label.cuda()
+
+            run_fwd_bwd(model, data, label, criterion, False)
+
+        cuda_non_model_data_list = np.array(model.param_op_hook._non_model_data_list) / 1024 ** 2
+        print("cuda_non_model_data_list", len(cuda_non_model_data_list))
+        # print(model.param_op_hook._non_model_data_list)
+
+        del model
+
+
+
+if __name__ == '__main__':
+    run_param_wrapper_testing()
\ No newline at end of file