Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-11-04 07:58:42 +00:00)
[fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
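
The commit adds a use_fp8 switch to HybridParallelPlugin that routes eligible F.linear calls through the fp8 kernel via a new parameter-op hook. Below is a minimal sketch of turning it on through the usual booster workflow; it is an assumed usage example, not part of the commit: the model, optimizer, and parallel sizes are placeholders, the script is expected to run under torchrun, and launch_from_torch() with no arguments assumes a recent ColossalAI.

    import colossalai
    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch()

    # placeholder model/optimizer; the fp8 matmul needs an fp8-capable GPU (compute capability >= 9.0)
    model = torch.nn.Linear(1024, 1024).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=1,
        precision="bf16",
        use_fp8=True,             # new flag: run eligible F.linear calls through linear_fp8
        fp8_communication=False,  # existing flag: fp8 for collectives, independent of use_fp8
    )
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)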

colossalai/booster/plugin/fp8_hook.py (new file, +23)
@@ -0,0 +1,23 @@
+import torch.nn.functional as F
+
+from colossalai.quantization.fp8 import linear_fp8
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+
+class FP8Hook(ColoParamOpHook):
+    def pre_forward(self, params) -> None:
+        pass
+
+    def post_forward(self, params) -> None:
+        pass
+
+    def pre_backward(self, params) -> None:
+        pass
+
+    def post_backward(self, params) -> None:
+        pass
+
+    def rewrite_op(self, func):
+        if func is F.linear:
+            return linear_fp8
+        return func
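
FP8Hook only changes behavior when an op is dispatched through ColoParameter's __torch_function__ with the hook registered on ColoParamOpHookManager, which is exactly how HybridParallelModule wires it up in the hunks below. The following is a small assumed sketch of that path outside the plugin, modeled on the plugin wiring and the new test; it needs a GPU with compute capability >= 9.0 for the fp8 kernel.

    import torch
    import torch.nn as nn
    from colossalai.booster.plugin.fp8_hook import FP8Hook
    from colossalai.tensor.colo_parameter import ColoParameter
    from colossalai.tensor.param_op_hook import ColoParamOpHookManager

    linear = nn.Linear(16, 32, bias=False).to(device="cuda", dtype=torch.bfloat16)
    linear.weight.__class__ = ColoParameter  # same in-place conversion the plugin applies to hooked params

    x = torch.rand(2, 16, device="cuda", dtype=torch.bfloat16)
    with ColoParamOpHookManager.use_hooks(FP8Hook()):
        y = linear(x)  # F.linear on a ColoParameter weight is rewritten to linear_fp8 here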

@@ -40,6 +40,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
+from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase
 
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         ddp_config: dict,
         custom_policy: Policy,
         overlap_allgather: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.use_dpp = use_ddp
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
+        self.use_fp8 = use_fp8
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)
 
         super().__init__(module)
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
@@ -223,7 +230,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             wait_all_gather_handle(p)
 
     def _wait_all_gather(self):
-        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+        return (
+            ColoParamOpHookManager.use_hooks(*self.op_hooks)
+            if (self.overlap_allgather or self.use_fp8)
+            else nullcontext()
+        )
 
 
 def get_param_info(optim: Optimizer):
@@ -1019,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
 
@@ -1063,6 +1075,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.use_fp8 = use_fp8
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1243,6 +1256,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
+                use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:

@@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function):
         if bias is not None:
             assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
         # ensure x and w are row-major
-        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
+        x = x.contiguous()
+        w = w.contiguous()
         ctx.x_shape = x.shape
         ctx.has_bias = bias is not None
         ctx.out_dtype = x.dtype
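
The fp8 linear compatibility fix above stops rejecting non-contiguous inputs and instead copies them into row-major layout, so transposed or otherwise strided views coming out of upstream ops no longer trip the assertion. A tiny standalone illustration of the tensor property involved (not part of the commit):

    import torch

    x = torch.rand(4, 8)
    y = x.t()                  # transposed view: shares storage, not row-major
    assert not y.is_contiguous()

    z = y.contiguous()         # materializes a row-major copy, as _LinearFp8 now does for x and w
    assert z.is_contiguous() and torch.equal(z, y)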

@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                 args, kwargs = replace_args(args, kwargs, new_args)
+                with torch._C.DisableTorchFunction():
+                    func = ColoParamOpHookManager.rewrite_op(func)
                 ret = super().__torch_function__(func, types, args, kwargs)
                 with torch._C.DisableTorchFunction():
                     ret = ColoParamOpHookManager.post_op(params, ret)

@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
     def post_backward(self, params: List[torch.Tensor]) -> None:
         pass
 
+    def rewrite_op(self, func) -> Any:
+        return func
+
 
 class ColoParamOpHookManager:
     """
@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
     def has_hook() -> bool:
         return len(ColoParamOpHookManager.hooks) > 0
 
+    @staticmethod
+    def rewrite_op(func) -> Any:
+        for hook in ColoParamOpHookManager.hooks:
+            func = hook.rewrite_op(func)
+        return func
+
 
 class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
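
With rewrite_op added to both the hook base class and the manager, every registered hook gets a chance to substitute the op, in registration order; FP8Hook uses this to swap F.linear for linear_fp8. A short sketch of the expected behavior on the post-commit code (an assumed illustration; the F.gelu line is just an arbitrary op left untouched):

    import torch.nn.functional as F

    from colossalai.booster.plugin.fp8_hook import FP8Hook
    from colossalai.quantization.fp8 import linear_fp8
    from colossalai.tensor.param_op_hook import ColoParamOpHookManager

    # with no hooks registered, ops pass through rewrite_op unchanged
    assert ColoParamOpHookManager.rewrite_op(F.linear) is F.linear

    with ColoParamOpHookManager.use_hooks(FP8Hook()):
        assert ColoParamOpHookManager.rewrite_op(F.linear) is linear_fp8  # swapped by FP8Hook
        assert ColoParamOpHookManager.rewrite_op(F.gelu) is F.gelu        # other ops untouched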

tests/test_fp8/test_fp8_hook.py (new file, +50)
@@ -0,0 +1,50 @@
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.accelerator import get_accelerator
+from colossalai.booster.plugin.fp8_hook import FP8Hook
+from colossalai.quantization.fp8 import linear_fp8
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import get_current_device
+
+REPLACED = False
+TRIGGERED = False
+
+
+def new_linear_fp8(x, w, bias=None):
+    global TRIGGERED
+    TRIGGERED = True
+    return linear_fp8(x, w, bias)
+
+
+class FP8TestHook(FP8Hook):
+    def rewrite_op(self, func):
+        func = super().rewrite_op(func)
+        if func is linear_fp8:
+            global REPLACED
+            REPLACED = True
+            return new_linear_fp8
+        return func
+
+
+D_IN, D_OUT = 16, 32
+B, S = 2, 64
+DTYPE = torch.bfloat16
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+def test_fp8_hook():
+    # create tensors
+    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
+    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    w.__class__ = ColoParameter
+    w.__init__(w, requires_grad=True)
+    hook = FP8TestHook()
+    with ColoParamOpHookManager.use_hooks(hook):
+        o = F.linear(x, w)
+    assert o.shape == (B, S, D_OUT)
+    assert REPLACED
+    assert TRIGGERED