[zero] ZeroDDP supports controlling outputs' dtype (#1399)
This commit is contained in:
parent 4e98e938ce
commit 04c9a86af8
@@ -202,12 +202,17 @@ class ZeroDDP(ColoDDP):
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 force_outputs_fp32: bool = False) -> None:
         super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
+        self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = ZeROHookV2(gemini_manager)
         self.fp32_params: List[ColoParameter] = []
         self.overflow_counter = 0
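For orientation, a hedged sketch of the call site after this change. Only the `ZeroDDP(...)` calls are taken from this diff; the `ChunkManager`/`GeminiManager` construction and import paths are schematic (their constructors take version-specific arguments, and `ZeroDDP` requires an initialized `torch.distributed` process group), so adjust them for your ColossalAI version:

import torch.nn as nn
from colossalai.gemini import ChunkManager, GeminiManager   # paths per the 0.1.x layout
from colossalai.nn.parallel import ZeroDDP

# Schematic setup with placeholder arguments; real construction depends on
# the ColossalAI version and needs torch.distributed to be initialized.
chunk_manager = ChunkManager(chunk_size=64 * 1024)          # placeholder
gemini_manager = GeminiManager('cuda', chunk_manager)       # placeholder

model = nn.Linear(16, 16).half().cuda()

# Default: outputs keep the module's fp16 dtype.
fp16_model = ZeroDDP(model, gemini_manager)
# New in this commit: force fp32 outputs, e.g. to compute the loss in fp32.
fp32_model = ZeroDDP(model, gemini_manager, force_outputs_fp32=True)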
@@ -235,7 +240,9 @@ class ZeroDDP(ColoDDP):
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return _cast_float(outputs, torch.float)
+        if self.force_outputs_fp32:
+            return _cast_float(outputs, torch.float)
+        return outputs

     def _setup_grads_ptr(self):
         for p in self.module.parameters():
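The new branch hinges on the `_cast_float` helper that the old code already called unconditionally. Below is a rough, runnable sketch of what such a helper plausibly does — an assumption for illustration, not the library's exact code: it walks nested containers and casts only floating-point tensors, leaving integer tensors and non-tensors untouched.

import torch

# Assumed behavior of a _cast_float-style helper: recursively cast
# floating-point tensors inside tensors/lists/tuples/dicts to `dtype`.
def _cast_float(args, dtype: torch.dtype):
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        return args.to(dtype)
    if isinstance(args, (list, tuple)):
        return type(args)(_cast_float(x, dtype) for x in args)
    if isinstance(args, dict):
        return {k: _cast_float(v, dtype) for k, v in args.items()}
    return args

# Example: fp16 outputs are promoted to fp32; non-float values pass through.
outputs = {'logits': torch.ones(2, 2, dtype=torch.half), 'step': 3}
print(_cast_float(outputs, torch.float)['logits'].dtype)  # torch.float32

With `force_outputs_fp32=False` this cast is skipped entirely, so downstream code sees the module's native fp16 outputs instead of always paying for an fp32 copy.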