Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-03 18:19:58 +00:00)
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
@@ -40,6 +40,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
+from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase
 
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
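The FP8Hook being imported is defined outside this diff. A minimal sketch of what such a param-op hook could look like, assuming ColoParamOpHook exposes no-op pre/post callbacks plus a rewrite_op extension point; the colossalai.quantization.fp8.linear_fp8 import path is an assumption here:

import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8  # assumed location
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, grads) -> None:
        pass

    def post_backward(self, grads) -> None:
        pass

    def rewrite_op(self, func):
        # Swap eligible ops for their FP8 counterparts; everything
        # else passes through untouched.
        if func is F.linear:
            return linear_fp8
        return func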
@@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         ddp_config: dict,
         custom_policy: Policy,
         overlap_allgather: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.use_dpp = use_ddp
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
+        self.use_fp8 = use_fp8
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)
 
         super().__init__(module)
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
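The in-place p.__class__ = ColoParameter swap retags existing parameters so the op-hook machinery dispatches on them without copying storage or invalidating references already held elsewhere (e.g. by an optimizer). A self-contained toy illustration; ToyParameter is hypothetical and stands in for ColoParameter:

import torch


class ToyParameter(torch.nn.Parameter):
    # Stand-in for ColoParameter: a Parameter subclass that hook
    # machinery could dispatch on.
    def describe(self) -> str:
        return f"ToyParameter{tuple(self.shape)}"


p = torch.nn.Parameter(torch.randn(4, 4))
ref = p                      # e.g. a reference held by an optimizer
p.__class__ = ToyParameter   # retag in place: no data copy
assert ref is p and isinstance(ref, ToyParameter)
print(p.describe())          # ToyParameter(4, 4)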
@@ -223,7 +230,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
                 wait_all_gather_handle(p)
 
     def _wait_all_gather(self):
-        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+        return (
+            ColoParamOpHookManager.use_hooks(*self.op_hooks)
+            if (self.overlap_allgather or self.use_fp8)
+            else nullcontext()
+        )
 
 
 def get_param_info(optim: Optimizer):
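The rewritten _wait_all_gather keeps the call-site contract simple: it always returns something usable in a with statement, handing back nullcontext() when no hook is active. A self-contained sketch of the same pattern; HookedModule and its names are illustrative, not ColossalAI API:

from contextlib import contextmanager, nullcontext


class HookedModule:
    def __init__(self, hooks=None):
        self.hooks = list(hooks or [])

    @contextmanager
    def _use_hooks(self):
        # Stand-in for ColoParamOpHookManager.use_hooks(*self.op_hooks).
        print(f"activating {len(self.hooks)} hook(s)")
        try:
            yield
        finally:
            print("deactivating hooks")

    def _hook_context(self):
        # Same shape as the diff: build the hook context when any hook
        # is registered, otherwise return a no-op nullcontext().
        return self._use_hooks() if self.hooks else nullcontext()


m = HookedModule(hooks=["fp8"])
with m._hook_context():
    pass  # the forward pass would run here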
@@ -1019,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
 
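The signature now carries two independent FP8 switches. Judging from the commit title, use_fp8 enables FP8 AMP for compute while the pre-existing fp8_communication quantizes collectives; these semantics are inferred, not taken from docs. A hedged construction sketch (tp_size/pp_size/precision values are arbitrary):

from colossalai.booster.plugin import HybridParallelPlugin

# use_fp8 (new in this commit): FP8 AMP for eligible matmuls.
# fp8_communication (pre-existing): FP8-quantized collectives.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    precision="bf16",
    fp8_communication=False,
    use_fp8=True,
)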
@@ -1063,6 +1075,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.use_fp8 = use_fp8
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1243,6 +1256,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             ddp_config=self.ddp_config,
             custom_policy=self.custom_policy,
             overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
+            use_fp8=self.use_fp8,
         )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:
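This last hunk completes the wiring: the plugin forwards its use_fp8 flag into HybridParallelModule during configure(). To close the loop, a hedged end-to-end sketch of exercising it under torchrun; the model and optimizer are placeholders and launch details are elided:

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes a torchrun launch
plugin = HybridParallelPlugin(tp_size=2, pp_size=1, use_fp8=True)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024).cuda()
optimizer = torch.optim.AdamW(model.parameters())
# boost() wraps the model in HybridParallelModule, which now builds
# its op-hook list with FP8Hook because use_fp8=True.
model, optimizer, *_ = booster.boost(model, optimizer)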