diff --git a/colossalai/booster/mixed_precision/__init__.py b/colossalai/booster/mixed_precision/__init__.py index 3cf0ad28c..0df9d8415 100644 --- a/colossalai/booster/mixed_precision/__init__.py +++ b/colossalai/booster/mixed_precision/__init__.py @@ -1,17 +1,19 @@ from .bf16 import BF16MixedPrecision from .fp8 import FP8MixedPrecision from .fp16_apex import FP16ApexMixedPrecision +from .fp16_naive import FP16NaiveMixedPrecision from .fp16_torch import FP16TorchMixedPrecision from .mixed_precision_base import MixedPrecision __all__ = [ 'MixedPrecision', 'mixed_precision_factory', 'FP16_Apex_MixedPrecision', 'FP16_Torch_MixedPrecision', - 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision' + 'FP32_MixedPrecision', 'BF16_MixedPrecision', 'FP8_MixedPrecision', 'FP16NaiveMixedPrecision' ] _mixed_precision_mapping = { 'fp16': FP16TorchMixedPrecision, 'fp16_apex': FP16ApexMixedPrecision, + 'fp16_naive': FP16NaiveMixedPrecision, 'bf16': BF16MixedPrecision, 'fp8': FP8MixedPrecision } diff --git a/colossalai/booster/mixed_precision/fp16_naive.py b/colossalai/booster/mixed_precision/fp16_naive.py new file mode 100644 index 000000000..ef1ec1f42 --- /dev/null +++ b/colossalai/booster/mixed_precision/fp16_naive.py @@ -0,0 +1,5 @@ +from .mixed_precision_base import MixedPrecision + + +class FP16NaiveMixedPrecision(MixedPrecision): + pass