[Gemini] ParamOpHook -> ColoParamOpHook (#2080)

Author: Jiarui Fang
Date: 2022-12-05 17:11:06 +08:00
Committed by: GitHub
parent 4f21c9e8d9
commit b3b89865e2
7 changed files with 37 additions and 36 deletions


@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
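Note on the rename: this hunk only touches the manager import; per the commit title, the hook base class gets the same ParamOpHook -> ColoParamOpHook rename. Below is a minimal, hypothetical sketch of a custom hook under the new name; the pre_forward/post_forward/pre_backward/post_backward method names are assumed from the hook interface, not taken from this diff.

# Hypothetical sketch, not part of this commit: a no-op hook that counts
# parameter operations. Method names are assumed from the hook interface.
from typing import List

import torch

from colossalai.tensor.param_op_hook import ColoParamOpHook


class CountingHook(ColoParamOpHook):

    def __init__(self) -> None:
        super().__init__()
        self.num_ops = 0

    def pre_forward(self, params: List[torch.Tensor]) -> None:
        # called before an op that reads the given ColoParameters
        self.num_ops += 1

    def post_forward(self, params: List[torch.Tensor]) -> None:
        pass

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        self.num_ops += 1

    def post_backward(self, params: List[torch.Tensor]) -> None:
        pass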
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
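For context, ColoParamOpHookManager.use_hooks installs the given hooks for the duration of the block, so the wrapped module call above runs with the Gemini hook active. A hedged usage sketch outside ZeroDDP, reusing the hypothetical CountingHook from the previous sketch; `model` is an assumption, not defined in this commit.

# Usage sketch for the context manager shown above. `model` is assumed to be
# an nn.Module whose parameters are ColoParameters (e.g. one built for
# Gemini), since the hooks are only triggered by ops on ColoParameter tensors.
import torch

from colossalai.tensor.param_op_hook import ColoParamOpHookManager

hook = CountingHook()

with ColoParamOpHookManager.use_hooks(hook):
    out = model(torch.randn(2, 4))   # parameter ops fire pre_forward/post_forward

print("param ops seen:", hook.num_ops)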
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
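The backward paths are wrapped in the same manager context so the hooks stay installed while autograd runs; switch_to_backward() is a method of the Gemini hook itself (self.param_op_hook), not of the manager. A hedged sketch of the backward half of the pattern, again reusing the hypothetical hook, model, and output from the previous sketch.

# Backward-side sketch, reusing `hook`, `model`, and `out` from the previous
# sketch. Note: the switch_to_backward() used by ZeroDDP above belongs to its
# Gemini hook, not to ColoParamOpHook in general, so it is omitted here.
loss = out.sum()

with ColoParamOpHookManager.use_hooks(hook):
    loss.backward()                  # pre_backward/post_backward fire here

print("param ops seen after backward:", hook.num_ops)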