diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 0e35e694d..7420da8f4 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -202,12 +202,17 @@ class ZeroDDP(ColoDDP):
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous memory space.
             For more details, see the API reference of ``GeminiManager``.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 force_outputs_fp32: bool = False) -> None:
         super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
+        self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = ZeROHookV2(gemini_manager)
         self.fp32_params: List[ColoParameter] = []
         self.overflow_counter = 0
@@ -235,7 +240,9 @@ class ZeroDDP(ColoDDP):
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return _cast_float(outputs, torch.float)
+        if self.force_outputs_fp32:
+            return _cast_float(outputs, torch.float)
+        return outputs

     def _setup_grads_ptr(self):
         for p in self.module.parameters():
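
For context, a minimal usage sketch of the new flag. This is illustrative only: `model` is assumed to be an fp16 module and `gemini_manager` a configured GeminiManager, whose construction is omitted because it varies by version; `inputs` is a placeholder tensor.

    from colossalai.nn.parallel.data_parallel import ZeroDDP

    # `model` and `gemini_manager` are assumed to be set up elsewhere.
    ddp_model = ZeroDDP(model, gemini_manager, force_outputs_fp32=True)
    outputs = ddp_model(inputs)   # outputs are cast to torch.float32

    # With the default force_outputs_fp32=False, outputs keep the module's
    # compute dtype (typically fp16 under ZeRO), skipping the extra cast.

Before this patch, `forward` unconditionally cast outputs to fp32; making the cast opt-in lets callers keep fp16 outputs when a downstream consumer (e.g. an fp16-aware loss) does not need the extra precision.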