[deep_gemm] add deep_gemm hook

This commit is contained in:
hxwang 2025-03-17 10:47:25 +08:00
parent 9ef62832fd
commit 16e46efd79
10 changed files with 101 additions and 12 deletions

View File

@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase):
enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Defaults to False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
"""
@ -479,6 +480,7 @@ class GeminiPlugin(DPPluginBase):
enable_jit_fused: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
use_deep_gemm: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
@ -517,6 +519,7 @@ class GeminiPlugin(DPPluginBase):
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
use_deep_gemm=use_deep_gemm,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,

View File

@ -33,7 +33,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
from colossalai.shardformer.policies.base_policy import Policy
@ -70,6 +70,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
custom_policy: Policy,
overlap_allgather: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.shard_config = shard_config
@ -80,6 +81,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
self.require_grad_sync = True
self.overlap_allgather = overlap_allgather
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
shardformer = ShardFormer(shard_config)
if custom_policy is not None:
@ -119,7 +121,10 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
super().__init__(module)
self.op_hooks = []
if use_fp8:
self.op_hooks.append(FP8Hook())
if use_deep_gemm:
self.op_hooks.append(FP8DeepGemmHook())
else:
self.op_hooks.append(FP8Hook())
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8 or overlap_allgather:
@ -1044,6 +1049,7 @@ class HybridParallelPlugin(PipelinePluginBase):
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
inner_ring_size: int = None,
) -> None:
super().__init__()
@ -1097,6 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
if dp_outside:
self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@ -1323,6 +1330,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy=self.custom_policy,
overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if zero_stage == 0:

View File

@ -37,7 +37,7 @@ from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero import LowLevelZeroOptimizer
@ -72,6 +72,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
overlap_allgather: bool = False,
cast_inputs: bool = True,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
super().__init__(module)
self.dtype = None
@ -92,7 +93,10 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
if overlap_allgather:
self.op_hooks.append(ZeroOpHook())
if use_fp8:
self.op_hooks.append(FP8Hook())
if use_deep_gemm:
self.op_hooks.append(FP8DeepGemmHook())
else:
self.op_hooks.append(FP8Hook())
if overlap_allgather or use_fp8:
for p in module.parameters():
if p.requires_grad and type(p) is not ColoParameter:
@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cpu_offload (bool, optional): Whether to offload gradients, master weights and optimizer states to CPU. Defaults to False.
verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
"""
@ -427,6 +432,7 @@ class LowLevelZeroPlugin(DPPluginBase):
cast_inputs: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
extra_dp_size: int = 1,
) -> None:
super().__init__()
@ -466,6 +472,7 @@ class LowLevelZeroPlugin(DPPluginBase):
self.cast_inputs = cast_inputs
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@ -585,6 +592,7 @@ class LowLevelZeroPlugin(DPPluginBase):
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
cast_inputs=self.cast_inputs,
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)
# TODO: Support Galore + ZeRO

View File

@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size to make it choose a faster kernel. Defaults to 64.
overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 training. Defaults to False.
fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
"""
@ -222,6 +223,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
overlap_allgather: bool = False,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
self.logger = get_dist_logger()
if overlap_communication or zero_stage == 2:
@ -359,6 +361,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
self.mixed_dp_group = self.dp_group
self.use_fp8 = use_fp8
self.use_deep_gemm = use_deep_gemm
self.shard_config = ShardConfig(
tensor_parallel_process_group=self.tp_group,
@ -465,6 +468,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
ddp_config=self.ddp_config,
custom_policy=self.custom_policy,
use_fp8=self.use_fp8,
use_deep_gemm=self.use_deep_gemm,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:

View File

@ -1,6 +1,6 @@
import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
from colossalai.tensor.param_op_hook import ColoParamOpHook
@ -21,3 +21,23 @@ class FP8Hook(ColoParamOpHook):
if func is F.linear:
return linear_fp8
return func
class FP8DeepGemmHook(ColoParamOpHook):
"""Param op hook that rewrites F.linear to linear_fp8_deep_gemm, the deep_gemm fp8 matmul."""
def pre_forward(self, params) -> None:
pass
def post_forward(self, params) -> None:
pass
def pre_backward(self, params) -> None:
pass
def post_backward(self, params) -> None:
pass
def rewrite_op(self, func):
if func is F.linear:
return linear_fp8_deep_gemm
return func
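
A minimal sketch of exercising the new hook directly, mirroring the unit test added later in this commit (an FP8-capable device is assumed; the 128x128 shapes simply match what that test uses):

import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.quantization.fp8_hook import FP8DeepGemmHook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

w = nn.Parameter(torch.rand(128, 128, device="cuda", dtype=torch.bfloat16))
x = torch.rand(64, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True)
w.__class__ = ColoParameter  # the hook only intercepts ops whose params are ColoParameter
w.__init__(w, requires_grad=True)

with ColoParamOpHookManager.use_hooks(FP8DeepGemmHook()):
    out = F.linear(x, w)  # rewrite_op swaps F.linear for linear_fp8_deep_gemm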

View File

@ -15,7 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor import (
distribute_tensor,
@ -101,6 +101,7 @@ class GeminiDDP(ModelWrapper):
enable_async_reduce: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
use_deep_gemm: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@ -142,7 +143,10 @@ class GeminiDDP(ModelWrapper):
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.hooks = [self.param_op_hook]
if use_fp8:
self.hooks.append(FP8Hook())
if use_deep_gemm:
self.hooks.append(FP8DeepGemmHook())
else:
self.hooks.append(FP8Hook())
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()

View File

@ -63,7 +63,7 @@ However, there are other operations, like reductions, which require the dynamic
We support three AMP training methods and allow the user to train with AMP without any code changes. If you want to train with AMP, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Support for `bf16` will be added next.
Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when create the plugin object.
Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when creating the plugin object. The `deep_gemm` fp8 matmul is also supported and can be enabled by additionally specifying `use_deep_gemm`.
To reduce the inter-node communication volume in low-bandwidth scenarios, we support FP8 communication compression. Please specify the `fp8_communication` parameter when creating the plugin object.
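
Putting the three flags together, a sketch using `HybridParallelPlugin` (the parallel sizes here are placeholders; any plugin exposing these parameters works the same way):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    use_fp8=True,            # fp8 matmul for Linear layers
    use_deep_gemm=True,      # route the fp8 matmul through deep_gemm
    fp8_communication=True,  # compress inter-node collectives to fp8
)
booster = Booster(plugin=plugin)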

View File

@ -59,7 +59,7 @@ AMP stands for automatic mixed precision training.
We support three AMP training methods and allow users to train with AMP without changing their code. The booster supports AMP feature injection; if you want to use mixed precision training, specify the `mixed_precision` parameter when creating the booster instance. Support for `bf16` will be added later.
Currently we only support `fp8` mixed precision training for the `Linear` layer. If you need it, please specify the `use_fp8` parameter when creating the plugin instance.
Currently we only support `fp8` mixed precision training for the `Linear` layer. If you need it, please specify the `use_fp8` parameter when creating the plugin instance. To use the `deep_gemm` fp8 matmul, specify the `use_deep_gemm` parameter.
To reduce the inter-node communication load in low-bandwidth scenarios, we also support FP8 communication. If you need it, please specify the `fp8_communication` parameter when creating the plugin instance.

View File

@ -107,6 +107,7 @@ def main():
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--use_deep_gemm", action="store_true", default=False, help="for using deep gemm")
parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
parser.add_argument("--overlap_allgather", action="store_true")
@ -159,6 +160,7 @@ def main():
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
use_deep_gemm=args.use_deep_gemm,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "gemini_auto":
@ -173,6 +175,7 @@ def main():
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
use_deep_gemm=args.use_deep_gemm,
fp8_communication=args.use_fp8_comm,
)
elif args.plugin == "fsdp":
@ -252,6 +255,7 @@ def main():
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
use_deep_gemm=args.use_deep_gemm,
fp8_communication=args.use_fp8_comm,
scheduler_nodes=scheduler_nodes,
**hybrid_kwargs,
@ -271,6 +275,7 @@ def main():
precision="bf16",
overlap_p2p=args.overlap_p2p,
use_fp8=args.use_fp8,
use_deep_gemm=args.use_deep_gemm,
fp8_communication=args.use_fp8_comm,
)
else:

View File

@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device
@ -20,6 +20,12 @@ def new_linear_fp8(x, w, bias=None):
return linear_fp8(x, w, bias)
def new_deepgemm_fp8_gemm(lhs, rhs, out=None):
global TRIGGERED
TRIGGERED = True
return linear_fp8_deep_gemm(lhs, rhs, out)
class FP8TestHook(FP8Hook):
def rewrite_op(self, func):
func = super().rewrite_op(func)
@ -30,13 +36,26 @@ class FP8TestHook(FP8Hook):
return func
D_IN, D_OUT = 16, 32
class DeepGemmTestHook(FP8DeepGemmHook):
def rewrite_op(self, func):
func = super().rewrite_op(func)
if func is linear_fp8_deep_gemm:
global REPLACED
REPLACED = True
return new_deepgemm_fp8_gemm
return func
D_IN, D_OUT = 128, 128
B, S = 2, 64
DTYPE = torch.bfloat16
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
global REPLACED, TRIGGERED
REPLACED = False
TRIGGERED = False
# create tensors
w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
@ -48,3 +67,21 @@ def test_fp8_hook():
assert o.shape == (B, S, D_OUT)
assert REPLACED
assert TRIGGERED
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_deep_gemm_hook():
global REPLACED, TRIGGERED
REPLACED = False
TRIGGERED = False
# create tensors
w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
x = torch.rand(S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
w.__class__ = ColoParameter
w.__init__(w, requires_grad=True)
hook = DeepGemmTestHook()
with ColoParamOpHookManager.use_hooks(hook):
o = F.linear(x, w)
assert o.shape == (S, D_OUT)
assert REPLACED
assert TRIGGERED