Mirror of https://github.com/hpcaitech/ColossalAI.git
[Gemini] ParamOpHook -> ColoParamOpHook (#2080)
@@ -12,7 +12,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.utils import get_temp_total_chunk_on_cuda
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from colossalai.tensor.colo_parameter import ColoParameter, ColoTensor, ColoTensorSpec
-from colossalai.tensor.param_op_hook import ParamOpHookManager
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
 from colossalai.zero.utils.gemini_hook import GeminiZeROHook
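Downstream code is affected only by the rename: ParamOpHook / ParamOpHookManager become ColoParamOpHook / ColoParamOpHookManager in colossalai.tensor.param_op_hook. A minimal compatibility sketch for user code that has to run on releases from before and after this commit; the try/except shim is an illustration, not part of the commit, and it assumes the rename is the only change to the module's public names.

# Hypothetical shim: prefer the new names, fall back to the pre-#2080 ones.
try:
    from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager
except ImportError:
    from colossalai.tensor.param_op_hook import ParamOpHook as ColoParamOpHook
    from colossalai.tensor.param_op_hook import ParamOpHookManager as ColoParamOpHookManager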
@@ -259,7 +259,7 @@ class ZeroDDP(ColoDDP):
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter(*args)
-        with ParamOpHookManager.use_hooks(self.param_op_hook):
+        with ColoParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
             return _cast_float(outputs, torch.float)
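In this hunk the forward path of ZeroDDP is unchanged apart from the manager name: inputs are cast to fp16, gradients are cleared, the Gemini manager is notified, and the wrapped module runs with the param-op hook active. A minimal sketch of driving the renamed API directly follows; ColoParamOpHookManager.use_hooks is used exactly as in the hunk, while the four callback names on ColoParamOpHook (pre_forward, post_forward, pre_backward, post_backward) and the toy module are assumptions about the hook interface rather than something shown in this diff.

from typing import List

import torch
import torch.nn as nn

from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager


class LoggingHook(ColoParamOpHook):
    # Toy hook that reports which parameters each op touches.
    # The four callback names below are an assumption about the
    # ColoParamOpHook interface; they do not appear in this diff.
    def pre_forward(self, params: List[torch.Tensor]) -> None:
        print(f"pre_forward:  {len(params)} param(s)")

    def post_forward(self, params: List[torch.Tensor]) -> None:
        print(f"post_forward: {len(params)} param(s)")

    def pre_backward(self, params: List[torch.Tensor]) -> None:
        print(f"pre_backward: {len(params)} param(s)")

    def post_backward(self, params: List[torch.Tensor]) -> None:
        print(f"post_backward: {len(params)} param(s)")


model = nn.Linear(8, 8)

# Same structure as the forward hunk above: run the module while the hook is active.
# The callbacks only fire for ColoParameter-based modules, so a plain nn.Linear
# merely illustrates the call pattern without triggering them.
with ColoParamOpHookManager.use_hooks(LoggingHook()):
    out = model(torch.randn(4, 8))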
@@ -280,12 +280,12 @@ class ZeroDDP(ColoDDP):
         self.gemini_manager.post_iter()
 
     def backward(self, loss: torch.Tensor):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             loss.backward()
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
             torch.autograd.backward(tensor, grad)
         self._post_backward()
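The last two hunks make the same substitution in backward and backward_by_grad. Leaving ColossalAI aside, the only PyTorch-level difference between those two entry points is how the backward pass is seeded: a scalar loss versus an explicit output/gradient pair. A standalone sketch of that equivalence:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x = torch.randn(2, 4)

# Path 1: scalar loss, as in ZeroDDP.backward(loss).
loss = model(x).sum()
loss.backward()
grads_a = [p.grad.clone() for p in model.parameters()]

# Path 2: seed autograd explicitly, as in ZeroDDP.backward_by_grad(tensor, grad).
model.zero_grad(set_to_none=True)
out = model(x)
torch.autograd.backward(out, torch.ones_like(out))
grads_b = [p.grad.clone() for p in model.parameters()]

# Both paths accumulate identical parameter gradients.
for ga, gb in zip(grads_a, grads_b):
    assert torch.allclose(ga, gb)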