Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-04-27 19:36:13 +00:00)

[deep_gemm] add deep_gemm hook

parent 9ef62832fd
commit 16e46efd79

@@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """
@@ -479,6 +480,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused: bool = False,
         enable_async_reduce: bool = True,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:
@@ -517,6 +519,7 @@ class GeminiPlugin(DPPluginBase):
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
             use_fp8=use_fp8,
+            use_deep_gemm=use_deep_gemm,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,

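For context, a minimal sketch of how the new flag is meant to be consumed: a `GeminiPlugin` built with `use_fp8=True` and `use_deep_gemm=True` forwards both flags to `GeminiDDP`, which then registers `FP8DeepGemmHook` instead of `FP8Hook` (see the `gemini_ddp` hunks below). The launch call and model are placeholders, and fp8 assumes a GPU with compute capability >= 9.0; this is an illustration, not part of the commit.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()  # assumes the script is started with torchrun

plugin = GeminiPlugin(
    precision="bf16",
    use_fp8=True,        # fp8 mixed precision for Linear layers
    use_deep_gemm=True,  # route the fp8 matmul through deep_gemm
)
booster = Booster(plugin=plugin)

model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.GELU())  # placeholder model
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)
```
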
@@ -33,7 +33,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
 from colossalai.shardformer.policies.base_policy import Policy
@@ -70,6 +70,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         custom_policy: Policy,
         overlap_allgather: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -80,6 +81,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -119,7 +121,10 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         super().__init__(module)
         self.op_hooks = []
         if use_fp8:
-            self.op_hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.op_hooks.append(FP8DeepGemmHook())
+            else:
+                self.op_hooks.append(FP8Hook())
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8 or overlap_allgather:
@@ -1044,6 +1049,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()
@@ -1097,6 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1323,6 +1330,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 custom_policy=self.custom_policy,
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:

@@ -37,7 +37,7 @@ from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
@@ -72,6 +72,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         overlap_allgather: bool = False,
         cast_inputs: bool = True,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -92,7 +93,10 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8:
-            self.op_hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.op_hooks.append(FP8DeepGemmHook())
+            else:
+                self.op_hooks.append(FP8Hook())
         if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
@@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep_gemm matmul. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
         extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
     """
@@ -427,6 +432,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         extra_dp_size: int = 1,
     ) -> None:
         super().__init__()
@@ -466,6 +472,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.cast_inputs = cast_inputs

         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -585,6 +592,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
                 cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )

        # TODO: Support Galore + ZeRO

@@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep gemm for fp8 training. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """

@@ -222,6 +223,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         self.logger = get_dist_logger()
         if overlap_communication or zero_stage == 2:
@@ -359,6 +361,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.mixed_dp_group = self.dp_group

         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm

         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,
@@ -465,6 +468,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:

@@ -1,6 +1,6 @@
 import torch.nn.functional as F

-from colossalai.quantization.fp8 import linear_fp8
+from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
 from colossalai.tensor.param_op_hook import ColoParamOpHook


@@ -21,3 +21,23 @@ class FP8Hook(ColoParamOpHook):
         if func is F.linear:
             return linear_fp8
         return func
+
+
+class FP8DeepGemmHook(ColoParamOpHook):
+
+    def pre_forward(self, params) -> None:
+        pass
+
+    def post_forward(self, params) -> None:
+        pass
+
+    def pre_backward(self, params) -> None:
+        pass
+
+    def post_backward(self, params) -> None:
+        pass
+
+    def rewrite_op(self, func):
+        if func is F.linear:
+            return linear_fp8_deep_gemm
+        return func

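The new `FP8DeepGemmHook` follows the same parameter-op-hook pattern as `FP8Hook`: while it is active, `F.linear` calls involving `ColoParameter` weights are rewritten to `linear_fp8_deep_gemm`. Below is a minimal sketch of driving the hook directly, closely mirroring the test added in this commit; it assumes a GPU with compute capability >= 9.0 (the same requirement the test's skip condition enforces) and uses the same 128-aligned shapes as the test.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.quantization.fp8_hook import FP8DeepGemmHook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

# Shapes mirror the new test: D_IN = D_OUT = 128, sequence length 64.
w = nn.Parameter(torch.rand(128, 128, device=get_current_device(), dtype=torch.bfloat16))
x = torch.rand(64, 128, device=get_current_device(), dtype=torch.bfloat16, requires_grad=True)

# Parameter-op hooks only fire on ColoParameter, so convert the weight in place.
w.__class__ = ColoParameter
w.__init__(w, requires_grad=True)

# While the hook is active, F.linear on the ColoParameter weight is rewritten
# to linear_fp8_deep_gemm by FP8DeepGemmHook.rewrite_op.
with ColoParamOpHookManager.use_hooks(FP8DeepGemmHook()):
    out = F.linear(x, w)
```

In the plugins above, the same mechanism is applied automatically: when both `use_fp8` and `use_deep_gemm` are set, `FP8DeepGemmHook` is appended to the wrapper's hook list instead of `FP8Hook`.
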
@@ -15,7 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -101,6 +101,7 @@ class GeminiDDP(ModelWrapper):
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -142,7 +143,10 @@ class GeminiDDP(ModelWrapper):
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.hooks = [self.param_op_hook]
         if use_fp8:
-            self.hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.hooks.append(FP8DeepGemmHook())
+            else:
+                self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()

@@ -63,7 +63,7 @@ However, there are other operations, like reductions, which require the dynamic

 We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`.

-Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when create the plugin object.
+Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when creating the plugin object. The `deep_gemm` fp8 matmul can additionally be enabled by specifying `use_deep_gemm`.

 To reduce the communication volume inter nodes in low-bandwidth scenarios, we support FP8 communication compression. Please specify the `fp8_communication` parameter when create the plugin object.

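As a concrete illustration of the documentation change above, the sketch below enables fp8 training, the deep_gemm matmul path, and fp8 communication when constructing a plugin. The parallel sizes are placeholders and a torchrun-based distributed launch is assumed; treat it as an example under those assumptions rather than a prescribed configuration.

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the script is started with torchrun

plugin = HybridParallelPlugin(
    tp_size=2,               # placeholder parallel sizes
    pp_size=1,
    precision="bf16",
    use_fp8=True,            # fp8 mixed precision for Linear layers
    use_deep_gemm=True,      # use the deep_gemm fp8 matmul
    fp8_communication=True,  # compress inter-node communication to fp8
)
booster = Booster(plugin=plugin)
```
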
@@ -59,7 +59,7 @@ AMP stands for automatic mixed precision training.

 We support three AMP training methods and allow users to train with AMP without changing their code. The booster supports AMP feature injection: to use mixed precision training, specify the `mixed_precision` parameter when creating the booster instance; `bf16` will be supported later.

-Currently we only support `fp8` mixed precision training for the `Linear` layer. If you need it, specify the `use_fp8` parameter when creating the plugin instance.
+Currently we only support `fp8` mixed precision training for the `Linear` layer. If you need it, specify the `use_fp8` parameter when creating the plugin instance; to use the `deep_gemm` fp8 matmul, also specify the `use_deep_gemm` parameter.

 To reduce the communication load between nodes in low-bandwidth scenarios, we also support FP8 communication. If you need it, specify the `fp8_communication` parameter when creating the plugin instance.

@@ -107,6 +107,7 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--use_deep_gemm", action="store_true", default=False, help="for using deep gemm")
     parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
     parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
     parser.add_argument("--overlap_allgather", action="store_true")
@@ -159,6 +160,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":
@@ -173,6 +175,7 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":
@@ -252,6 +255,7 @@ def main():
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
             scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,
@@ -271,6 +275,7 @@ def main():
             precision="bf16",
             overlap_p2p=args.overlap_p2p,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     else:

@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F

 from colossalai.accelerator import get_accelerator
-from colossalai.quantization.fp8 import linear_fp8
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
@@ -20,6 +20,12 @@ def new_linear_fp8(x, w, bias=None):
     return linear_fp8(x, w, bias)


+def new_deepgemm_fp8_gemm(lhs, rhs, out=None):
+    global TRIGGERED
+    TRIGGERED = True
+    return linear_fp8_deep_gemm(lhs, rhs, out)
+
+
 class FP8TestHook(FP8Hook):
     def rewrite_op(self, func):
         func = super().rewrite_op(func)
@@ -30,13 +36,26 @@ class FP8TestHook(FP8Hook):
         return func


-D_IN, D_OUT = 16, 32
+class DeepGemmTestHook(FP8DeepGemmHook):
+    def rewrite_op(self, func):
+        func = super().rewrite_op(func)
+        if func is linear_fp8_deep_gemm:
+            global REPLACED
+            REPLACED = True
+            return new_deepgemm_fp8_gemm
+        return func
+
+
+D_IN, D_OUT = 128, 128
 B, S = 2, 64
 DTYPE = torch.bfloat16


 @pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
 def test_fp8_hook():
+    global REPLACED, TRIGGERED
+    REPLACED = False
+    TRIGGERED = False
     # create tensors
     w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
     x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
@@ -48,3 +67,21 @@ def test_fp8_hook():
     assert o.shape == (B, S, D_OUT)
     assert REPLACED
     assert TRIGGERED
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+def test_fp8_deep_gemm_hook():
+    global REPLACED, TRIGGERED
+    REPLACED = False
+    TRIGGERED = False
+    # create tensors
+    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
+    x = torch.rand(S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    w.__class__ = ColoParameter
+    w.__init__(w, requires_grad=True)
+    hook = DeepGemmTestHook()
+    with ColoParamOpHookManager.use_hooks(hook):
+        o = F.linear(x, w)
+    assert o.shape == (S, D_OUT)
+    assert REPLACED
+    assert TRIGGERED