Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-04-28 11:45:23 +00:00

[deep_gemm] add deep_gemm hook

Commit 16e46efd79, parent 9ef62832fd

@@ -435,6 +435,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
         enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep_gemm for fp8 matmul. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """

@@ -479,6 +480,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused: bool = False,
         enable_async_reduce: bool = True,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:

@@ -517,6 +519,7 @@ class GeminiPlugin(DPPluginBase):
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
             use_fp8=use_fp8,
+            use_deep_gemm=use_deep_gemm,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
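
Taken together, the GeminiPlugin changes above just thread the new flag through to GeminiDDP. A minimal usage sketch, not part of this commit, assuming the usual Booster wiring and a GPU with device capability >= 9.0 (the training loop is omitted):

# Illustrative only: enable the deep_gemm fp8 matmul path through the new plugin flag.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    precision="bf16",
    use_fp8=True,        # install the fp8 param-op hook for F.linear
    use_deep_gemm=True,  # prefer the deep_gemm fp8 kernel over the default fp8 path
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)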

@@ -33,7 +33,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule, ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils, is_share_sp_tp
 from colossalai.shardformer.policies.base_policy import Policy

@@ -70,6 +70,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         custom_policy: Policy,
         overlap_allgather: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config

@@ -80,6 +81,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:

@@ -119,7 +121,10 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         super().__init__(module)
         self.op_hooks = []
         if use_fp8:
-            self.op_hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.op_hooks.append(FP8DeepGemmHook())
+            else:
+                self.op_hooks.append(FP8Hook())
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8 or overlap_allgather:

@@ -1044,6 +1049,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         inner_ring_size: int = None,
     ) -> None:
         super().__init__()

@@ -1097,6 +1103,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)

@@ -1323,6 +1330,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 custom_policy=self.custom_policy,
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:

@@ -37,7 +37,7 @@ from colossalai.interface.optimizer import DistributedOptim
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer

@@ -72,6 +72,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         overlap_allgather: bool = False,
         cast_inputs: bool = True,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None

@@ -92,7 +93,10 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if overlap_allgather:
             self.op_hooks.append(ZeroOpHook())
         if use_fp8:
-            self.op_hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.op_hooks.append(FP8DeepGemmHook())
+            else:
+                self.op_hooks.append(FP8Hook())
         if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:

@@ -400,6 +404,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload (bool, optional): whether to offload grad, master weight and optimizer state to cpu. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including grad overflow will be printed. Defaults to False.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep_gemm matmul. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
         extra_dp_size (int, optional): The number of extra data parallel groups. Defaults to 1.
     """

@@ -427,6 +432,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
         extra_dp_size: int = 1,
     ) -> None:
         super().__init__()

@@ -466,6 +472,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.cast_inputs = cast_inputs

         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")

@@ -585,6 +592,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
                 cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )

         # TODO: Support Galore + ZeRO

@@ -171,6 +171,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
         overlap_p2p (bool, optional): Whether to overlap the p2p communication in pipeline parallelism.
         use_fp8 (bool, optional): Whether to enable fp8 mixed precision training. Defaults to False.
+        use_deep_gemm (bool, optional): Whether to use deep gemm for fp8 training. Defaults to False.
         fp8_communication (bool, optional): Whether to enable fp8 communication. Defaults to False.
     """

@@ -222,6 +223,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         self.logger = get_dist_logger()
         if overlap_communication or zero_stage == 2:

@@ -359,6 +361,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.mixed_dp_group = self.dp_group

         self.use_fp8 = use_fp8
+        self.use_deep_gemm = use_deep_gemm

         self.shard_config = ShardConfig(
             tensor_parallel_process_group=self.tp_group,

@@ -465,6 +468,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
                 use_fp8=self.use_fp8,
+                use_deep_gemm=self.use_deep_gemm,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if self.zero_stage == 0:

@@ -1,6 +1,6 @@
 import torch.nn.functional as F

-from colossalai.quantization.fp8 import linear_fp8
+from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
 from colossalai.tensor.param_op_hook import ColoParamOpHook


@@ -21,3 +21,23 @@ class FP8Hook(ColoParamOpHook):
         if func is F.linear:
             return linear_fp8
         return func
+
+
+class FP8DeepGemmHook(ColoParamOpHook):
+
+    def pre_forward(self, params) -> None:
+        pass
+
+    def post_forward(self, params) -> None:
+        pass
+
+    def pre_backward(self, params) -> None:
+        pass
+
+    def post_backward(self, params) -> None:
+        pass
+
+    def rewrite_op(self, func):
+        if func is F.linear:
+            return linear_fp8_deep_gemm
+        return func
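
FP8DeepGemmHook mirrors FP8Hook but rewrites F.linear to linear_fp8_deep_gemm. A minimal usage sketch, modeled on the test added at the end of this commit (the shapes, dtype, and the ColoParameter cast come from that test; a device with capability >= 9.0 is assumed, and nothing beyond what the commit adds is claimed):

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.quantization.fp8_hook import FP8DeepGemmHook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

# deep_gemm-friendly sizes, as in the test below
w = nn.Parameter(torch.rand(128, 128, device=get_current_device(), dtype=torch.bfloat16))
w.__class__ = ColoParameter  # the param-op hook machinery only intercepts ops on ColoParameter
w.__init__(w, requires_grad=True)
x = torch.rand(64, 128, device=get_current_device(), dtype=torch.bfloat16, requires_grad=True)

with ColoParamOpHookManager.use_hooks(FP8DeepGemmHook()):
    out = F.linear(x, w)  # rewritten to linear_fp8_deep_gemm by the hook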

@@ -15,7 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,

@@ -101,6 +101,7 @@ class GeminiDDP(ModelWrapper):
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
+        use_deep_gemm: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False

@@ -142,7 +143,10 @@ class GeminiDDP(ModelWrapper):
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.hooks = [self.param_op_hook]
         if use_fp8:
-            self.hooks.append(FP8Hook())
+            if use_deep_gemm:
+                self.hooks.append(FP8DeepGemmHook())
+            else:
+                self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()

@@ -63,7 +63,7 @@ However, there are other operations, like reductions, which require the dynamic

 We supported three AMP training methods and allowed the user to train with AMP with no code. If you want to train with amp, just assign `mixed_precision` with `fp16` when you instantiate the `Booster`. Next we will support `bf16`.

-Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when create the plugin object.
+Currently we only support `fp8` mixed precision training for the `Linear` layer. Please specify the `use_fp8` parameter when creating the plugin object. The `deep_gemm` fp8 matmul can be enabled by specifying `use_deep_gemm`.

 To reduce the communication volume inter nodes in low-bandwidth scenarios, we support FP8 communication compression. Please specify the `fp8_communication` parameter when create the plugin object.

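
A minimal sketch of what the documented flags look like at plugin-construction time (illustrative only and not part of this commit; the parallelism sizes are placeholder assumptions, and the other plugins that accept these flags are configured the same way):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,               # placeholder parallelism sizes
    pp_size=1,
    precision="bf16",
    use_fp8=True,            # fp8 mixed precision for Linear layers
    use_deep_gemm=True,      # use the deep_gemm fp8 matmul kernel
    fp8_communication=True,  # compress inter-node communication to fp8
)
booster = Booster(plugin=plugin)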

@@ -59,7 +59,7 @@ AMP stands for automatic mixed precision training. (Simplified-Chinese documentation)

 We support three AMP training methods and allow users to train with AMP without changing their code. The booster supports AMP feature injection: if you want mixed precision training, specify the `mixed_precision` parameter when creating the booster instance; `bf16` will be supported later.

-Currently we only support `fp8` mixed precision training for the `Linear` layer; to use it, specify the `use_fp8` parameter when creating the plugin instance.
+Currently we only support `fp8` mixed precision training for the `Linear` layer; to use it, specify the `use_fp8` parameter when creating the plugin instance, and specify the `use_deep_gemm` parameter to enable the `deep_gemm` fp8 matmul.

 To reduce the inter-node communication load in low-bandwidth scenarios, we also support FP8 communication; to use it, specify the `fp8_communication` parameter when creating the plugin instance.


@@ -107,6 +107,7 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--use_deep_gemm", action="store_true", default=False, help="for using deep gemm")
     parser.add_argument("--use_fp8", action="store_true", default=False, help="for using fp8 linear")
     parser.add_argument("--overlap_p2p", action="store_true", default=True, help="for using overlap p2p")
     parser.add_argument("--overlap_allgather", action="store_true")

@@ -159,6 +160,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "gemini_auto":

@@ -173,6 +175,7 @@ def main():
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     elif args.plugin == "fsdp":

@@ -252,6 +255,7 @@ def main():
             enable_metadata_cache=not args.no_cache,
             overlap_allgather=args.overlap_allgather,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
             scheduler_nodes=scheduler_nodes,
             **hybrid_kwargs,

@@ -271,6 +275,7 @@ def main():
             precision="bf16",
             overlap_p2p=args.overlap_p2p,
             use_fp8=args.use_fp8,
+            use_deep_gemm=args.use_deep_gemm,
             fp8_communication=args.use_fp8_comm,
         )
     else:

@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F

 from colossalai.accelerator import get_accelerator
-from colossalai.quantization.fp8 import linear_fp8
-from colossalai.quantization.fp8_hook import FP8Hook
+from colossalai.quantization.fp8 import linear_fp8, linear_fp8_deep_gemm
+from colossalai.quantization.fp8_hook import FP8DeepGemmHook, FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device

@@ -20,6 +20,12 @@ def new_linear_fp8(x, w, bias=None):
     return linear_fp8(x, w, bias)


+def new_deepgemm_fp8_gemm(lhs, rhs, out=None):
+    global TRIGGERED
+    TRIGGERED = True
+    return linear_fp8_deep_gemm(lhs, rhs, out)
+
+
 class FP8TestHook(FP8Hook):
     def rewrite_op(self, func):
         func = super().rewrite_op(func)

@@ -30,13 +36,26 @@ class FP8TestHook(FP8Hook):
         return func


-D_IN, D_OUT = 16, 32
+class DeepGemmTestHook(FP8DeepGemmHook):
+    def rewrite_op(self, func):
+        func = super().rewrite_op(func)
+        if func is linear_fp8_deep_gemm:
+            global REPLACED
+            REPLACED = True
+            return new_deepgemm_fp8_gemm
+        return func
+
+
+D_IN, D_OUT = 128, 128
 B, S = 2, 64
 DTYPE = torch.bfloat16


 @pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
 def test_fp8_hook():
+    global REPLACED, TRIGGERED
+    REPLACED = False
+    TRIGGERED = False
     # create tensors
     w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
     x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)

@@ -48,3 +67,21 @@ def test_fp8_hook():
     assert o.shape == (B, S, D_OUT)
     assert REPLACED
     assert TRIGGERED
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+def test_fp8_deep_gemm_hook():
+    global REPLACED, TRIGGERED
+    REPLACED = False
+    TRIGGERED = False
+    # create tensors
+    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
+    x = torch.rand(S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    w.__class__ = ColoParameter
+    w.__init__(w, requires_grad=True)
+    hook = DeepGemmTestHook()
+    with ColoParamOpHookManager.use_hooks(hook):
+        o = F.linear(x, w)
+    assert o.shape == (S, D_OUT)
+    assert REPLACED
+    assert TRIGGERED