From 406f984063423042e25d0723258530ba506a44a9 Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Thu, 15 Aug 2024 10:41:22 +0800
Subject: [PATCH] [plugin] add cast inputs option for zero (#6003)

---
 colossalai/booster/plugin/low_level_zero_plugin.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 66491821c..e4c386a22 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -62,7 +62,9 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,7 +76,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
             module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
-        if self.dtype is not None:
+        if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         if overlap_allgather:
@@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -360,6 +363,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.cast_inputs = cast_inputs
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -474,7 +478,10 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                cast_inputs=self.cast_inputs,
            )
 
         # TODO: Support Galore + ZeRO
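
For context (not part of the patch): a minimal usage sketch of the new cast_inputs flag,
assuming a standard ColossalAI launch and Booster setup. The toy model, optimizer, and
hyperparameters below are hypothetical placeholders, not taken from this commit.

    import torch.nn as nn
    from torch.optim import Adam

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    # Initialize the distributed environment (torchrun-style launch assumed).
    colossalai.launch_from_torch()

    # With cast_inputs=False, forward() inputs keep their original dtype
    # (e.g. fp32) instead of being cast to the fp16/bf16 module dtype;
    # the default (True) preserves the previous casting behavior.
    plugin = LowLevelZeroPlugin(stage=2, precision="fp16", cast_inputs=False)
    booster = Booster(plugin=plugin)

    # Hypothetical toy model and optimizer for illustration only.
    model = nn.Linear(16, 16)
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

Keeping cast_inputs=True as the default means existing callers see no behavior change;
opting out is useful when a module manages its own input dtypes.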