diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index a81f9b05d..af135db44 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase): enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False. enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False. verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. """ @@ -479,6 +480,7 @@ class GeminiPlugin(DPPluginBase): enable_jit_fused: bool = False, enable_async_reduce: bool = True, use_fp8: bool = False, + use_deep_gemm: bool = False, verbose: bool = False, fp8_communication: bool = False, ) -> None: @@ -517,6 +519,7 @@ class GeminiPlugin(DPPluginBase): enable_async_reduce=enable_async_reduce, fp8_communication=fp8_communication, use_fp8=use_fp8, + use_deep_gemm=use_deep_gemm, ) self.zero_optim_config = dict( gpu_margin_mem_ratio=gpu_margin_mem_ratio, diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1e0f7be24..ac8c2ab9d 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -33,7 +33,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp from colossalai.shardformer.policies.base_policy import Policy @@ -70,6 +70,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): custom_policy: Policy, overlap_allgather: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -80,6 +81,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.require_grad_sync = True self.overlap_allgather = overlap_allgather self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -119,7 +121,10 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): super().__init__(module) self.op_hooks = [] if use_fp8: - self.op_hooks.append(FP8Hook()) + if use_deep_gemm: + self.op_hooks.append(FP8DeepGemmHook()) + else: + self.op_hooks.append(FP8Hook()) if overlap_allgather: self.op_hooks.append(ZeroOpHook()) if use_fp8 or overlap_allgather: @@ -1044,6 +1049,7 @@ class HybridParallelPlugin(PipelinePluginBase): overlap_allgather: bool = False, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, inner_ring_size: int = None, ) -> None: super().__init__() @@ -1097,6 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) @@ -1323,6 +1330,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy=self.custom_policy, overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 642969be3..0c02189bb 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -37,7 +37,7 @@ from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed from colossalai.quantization import BnbQuantizationConfig, quantize_model -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero import LowLevelZeroOptimizer @@ -72,6 +72,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: super().__init__(module) self.dtype = None @@ -92,7 +93,10 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): if overlap_allgather: self.op_hooks.append(ZeroOpHook()) if use_fp8: - self.op_hooks.append(FP8Hook()) + if use_deep_gemm: + self.op_hooks.append(FP8DeepGemmHook()) + else: + self.op_hooks.append(FP8Hook()) if overlap_allgather or use_fp8: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: @@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False. verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep_gemm matmul. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1. """ @@ -427,6 +432,7 @@ class LowLevelZeroPlugin(DPPluginBase): cast_inputs: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, extra_dp_size: int = 1, ) -> None: super().__init__() @@ -466,6 +472,7 @@ class LowLevelZeroPlugin(DPPluginBase): self.cast_inputs = cast_inputs self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -585,6 +592,7 @@ class LowLevelZeroPlugin(DPPluginBase): overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], cast_inputs=self.cast_inputs, use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) # TODO: Support Galore + ZeRO diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py index a733fc5f5..8e0ac6bcf 100644 --- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py @@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64. overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism. use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False. + use_deep_gemm (bool, optional): Whether to use deep gemm for fp8 training. Defaults to False. fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False. """ @@ -222,6 +223,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): overlap_allgather: bool = False, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: self.logger = get_dist_logger() if overlap_communication or zero_stage == 2: @@ -359,6 +361,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): self.mixed_dp_group = self.dp_group self.use_fp8 = use_fp8 + self.use_deep_gemm = use_deep_gemm self.shard_config = ShardConfig( tensor_parallel_process_group=self.tp_group, @@ -465,6 +468,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin): ddp_config=self.ddp_config, custom_policy=self.custom_policy, use_fp8=self.use_fp8, + use_deep_gemm=self.use_deep_gemm, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if self.zero_stage == 0: diff --git a/colossalai/quantization/fp8_hook.py b/colossalai/quantization/fp8_hook.py index 6171dd755..8569358a4 100644 --- a/colossalai/quantization/fp8_hook.py +++ b/colossalai/quantization/fp8_hook.py @@ -1,6 +1,6 @@ import torch.nn.functional as F -from colossalai.quantization.fp8 import linear_fp8 +from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm from colossalai.tensor.param_op_hook import ColoParamOpHook @@ -21,3 +21,23 @@ class FP8Hook(ColoParamOpHook): if func is F.linear: return linear_fp8 return func + + +class FP8DeepGemmHook(ColoParamOpHook): + + def pre_forward(self, params) -> None: + pass + + def post_forward(self, params) -> None: + pass + + def pre_backward(self, params) -> None: + pass + + def post_backward(self, params) -> None: + pass + + def rewrite_op(self, func): + if func is F.linear: + return linear_fp8_deep_gemm + return func diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 9e89e8827..f4e2ce0ae 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -15,7 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_ from colossalai.interface import ModelWrapper from colossalai.lazy import LazyTensor from colossalai.logging import get_dist_logger -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.d_tensor import ( distribute_tensor, @@ -101,6 +101,7 @@ class GeminiDDP(ModelWrapper): enable_async_reduce: bool = True, fp8_communication: bool = False, use_fp8: bool = False, + use_deep_gemm: bool = False, ) -> None: assert mixed_precision in (torch.float16, torch.bfloat16) reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False @@ -142,7 +143,10 @@ class GeminiDDP(ModelWrapper): self.param_op_hook = GeminiZeROHook(self.gemini_manager) self.hooks = [self.param_op_hook] if use_fp8: - self.hooks.append(FP8Hook()) + if use_deep_gemm: + self.hooks.append(FP8DeepGemmHook()) + else: + self.hooks.append(FP8Hook()) self.fp32_params: List[torch.Tensor] = list() self.fp16_params: List[ColoParameter] = list() self.grads_device: Dict[torch.Tensor, torch.device] = dict() diff --git a/docs/source/en/features/mixed_precision_training_with_booster.md b/docs/source/en/features/mixed_precision_training_with_booster.md index 1e17c2bb5..1af074578 100644 --- a/docs/source/en/features/mixed_precision_training_with_booster.md +++ b/docs/source/en/features/mixed_precision_training_with_booster.md @@ -63,7 +63,7 @@ However, there are other operations, like reductions, which require the dynamic We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`. -Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when create the plugin object. +Currently we only support `fp8` mixed precision training for the `Linear` layer, please specify the `use_fp8` parameter when create the plugin object. `deep_gemm` fp8 matmul is adopted which can be enabled by specifying `use_deep_gemm`. To reduce the communication volume inter nodes in low-bandwidth scenarios, we support FP8 communication compression. Please specify the `fp8_communication` parameter when create the plugin object. diff --git a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md index 93a69830c..83e2adbfb 100644 --- a/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md +++ b/docs/source/zh-Hans/features/mixed_precision_training_with_booster.md @@ -59,7 +59,7 @@ AMP 代表自动混合精度训练。 我们支持三种 AMP 训练方法,并允许用户在没有改变代码的情况下使用 AMP 进行训练。booster 支持 amp 特性注入,如果您要使用混合精度训练,则在创建 booster 实例时指定`mixed_precision`参数; 后续将会拓展`bf16`. -我们目前只支持`Linear`层的`fp8`混合精度训练,如果您需要使用,请在创建 plugin实例时指定`use_fp8`参数。 +我们目前只支持`Linear`层的`fp8`混合精度训练,如果您需要使用,请在创建 plugin实例时指定`use_fp8`参数,`deep_gemm`fp8矩阵乘法适配请指定`use_deep_gemm`参数。 为了减少低带宽场景下多机之间的通讯负载,我们还支持了FP8通讯。如果您需要使用,请在创建 plugin实例时指定`fp8_communication`参数。 diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 2964f83f4..231173696 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -107,6 +107,7 @@ def main(): parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") parser.add_argument("--no_cache", action="store_true") parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication") + parser.add_argument("--use_deep_gemm", action="store_true", default=False, help="for using deep gemm") parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear") parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p") parser.add_argument("--overlap_allgather", action="store_true") @@ -159,6 +160,7 @@ def main(): max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) elif args.plugin == "gemini_auto": @@ -173,6 +175,7 @@ def main(): enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) elif args.plugin == "fsdp": @@ -252,6 +255,7 @@ def main(): enable_metadata_cache=not args.no_cache, overlap_allgather=args.overlap_allgather, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, scheduler_nodes=scheduler_nodes, **hybrid_kwargs, @@ -271,6 +275,7 @@ def main(): precision="bf16", overlap_p2p=args.overlap_p2p, use_fp8=args.use_fp8, + use_deep_gemm=args.use_deep_gemm, fp8_communication=args.use_fp8_comm, ) else: diff --git a/tests/test_fp8/test_fp8_hook.py b/tests/test_fp8/test_fp8_hook.py index abd5d09e1..a7bc3b7b9 100644 --- a/tests/test_fp8/test_fp8_hook.py +++ b/tests/test_fp8/test_fp8_hook.py @@ -4,8 +4,8 @@ import torch.nn as nn import torch.nn.functional as F from colossalai.accelerator import get_accelerator -from colossalai.quantization.fp8 import linear_fp8 -from colossalai.quantization.fp8_hook import FP8Hook +from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm +from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.utils import get_current_device @@ -20,6 +20,12 @@ def new_linear_fp8(x, w, bias=None): return linear_fp8(x, w, bias) +def new_deepgemm_fp8_gemm(lhs, rhs, out=None): + global TRIGGERED + TRIGGERED = True + return linear_fp8_deep_gemm(lhs, rhs, out) + + class FP8TestHook(FP8Hook): def rewrite_op(self, func): func = super().rewrite_op(func) @@ -30,13 +36,26 @@ class FP8TestHook(FP8Hook): return func -D_IN, D_OUT = 16, 32 +class DeepGemmTestHook(FP8DeepGemmHook): + def rewrite_op(self, func): + func = super().rewrite_op(func) + if func is linear_fp8_deep_gemm: + global REPLACED + REPLACED = True + return new_deepgemm_fp8_gemm + return func + + +D_IN, D_OUT = 128, 128 B, S = 2, 64 DTYPE = torch.bfloat16 @pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") def test_fp8_hook(): + global REPLACED, TRIGGERED + REPLACED = False + TRIGGERED = False # create tensors w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) @@ -48,3 +67,21 @@ def test_fp8_hook(): assert o.shape == (B, S, D_OUT) assert REPLACED assert TRIGGERED + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +def test_fp8_deep_gemm_hook(): + global REPLACED, TRIGGERED + REPLACED = False + TRIGGERED = False + # create tensors + w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) + x = torch.rand(S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) + w.__class__ = ColoParameter + w.__init__(w, requires_grad=True) + hook = DeepGemmTestHook() + with ColoParamOpHookManager.use_hooks(hook): + o = F.linear(x, w) + assert o.shape == (S, D_OUT) + assert REPLACED + assert TRIGGERED