From 971b16a74f84415ac3cfde80d762149b2e65c036 Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Thu, 22 Aug 2024 03:00:40 +0000
Subject: [PATCH 1/2] fix

---
 colossalai/booster/plugin/low_level_zero_plugin.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 4188491c2..4082ffada 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -63,7 +63,7 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -77,7 +77,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.module = module
         self.convert_fn = None
         self.use_fp8 = use_fp8
-        if self.dtype is not None:
+        if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         self.op_hooks = []
@@ -342,6 +342,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
@@ -372,6 +373,8 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
+        self.cast_inputs = cast_inputs
+        self.use_fp8 = use_fp8
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -490,6 +493,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
             )
 

From a29255417986bc3c938ab92d8fd22e8906aa0e92 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 22 Aug 2024 03:04:43 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/booster/plugin/low_level_zero_plugin.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 4082ffada..42bb49bc9 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -63,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
        super().__init__(module)
        self.dtype = None
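
For context, a minimal usage sketch of the new cast_inputs flag introduced by this patch, assuming the usual Booster workflow; the launch call, toy model, and optimizer below are illustrative only and not part of the change.

# Sketch: wrap a model with LowLevelZeroPlugin while keeping inputs in their
# original dtype (cast_inputs=False); the default (True) preserves the old
# behaviour of casting floating-point inputs to the AMP dtype.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Illustrative: assumes a torchrun-style distributed environment is already set up.
colossalai.launch_from_torch()

plugin = LowLevelZeroPlugin(stage=2, precision="bf16", cast_inputs=False)
booster = Booster(plugin=plugin)

model = nn.Linear(16, 16)  # toy model for illustration
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)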