[fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
Author: Hongxin Liu
Date: 2024-08-09 14:09:48 +08:00
Committed by: GitHub
Parent: 4b9bec8176
Commit: 8241c0c054
7 changed files with 21 additions and 7 deletions
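
For context, a minimal usage sketch of the new flag follows. It assumes the Gemini plugin exposes a `use_fp8` argument and forwards it to `GeminiDDP`, as the commit title and diff suggest; the toy model, hyperparameters, and `precision` choice are placeholders, and the script must be started with `torchrun` or `colossalai run`.

```python
# Hedged usage sketch (not part of this commit): enabling FP8 compute through
# the Gemini plugin. Assumes GeminiPlugin accepts `use_fp8` and passes it to
# GeminiDDP. Launch with torchrun on FP8-capable GPUs.
import torch
import torch.nn as nn
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# use_fp8=True is the flag added in this commit (assumed plugin-level name).
plugin = GeminiPlugin(precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(8, 1024, device="cuda")
loss = model(x).float().mean()
booster.backward(loss, optimizer)
optimizer.step()
```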


@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
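
The `FP8Hook` appended here is defined elsewhere in the repo and is not shown in this diff. As a rough illustration of the kind of transform such a hook applies around linear ops, the sketch below simulates a per-tensor-scaled FP8 matmul in plain PyTorch; it is not ColossalAI's implementation, and the helper names are made up.

```python
# Illustration only (not ColossalAI's FP8Hook): quantize activations and
# weights to float8_e4m3fn with per-tensor scales, then dequantize and matmul
# in the original precision. Requires PyTorch >= 2.1 for the float8 dtypes.
import torch

def quantize_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale x into the e4m3 representable range and cast to float8."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().max().clamp(min=1e-12) / fp8_max
    return (x / scale).to(torch.float8_e4m3fn), scale

def fp8_linear_sim(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Simulated FP8 linear: quantize, dequantize, matmul in float32."""
    xq, sx = quantize_fp8(x)
    wq, sw = quantize_fp8(w)
    return (xq.float() * sx) @ (wq.float() * sw).t()

x, w = torch.randn(4, 16), torch.randn(8, 16)
print(fp8_linear_sim(x, w).shape)  # torch.Size([4, 8])
```
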
@@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)
         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):
     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):
     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()
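
Swapping `self.param_op_hook` for `*self.hooks` at every `use_hooks` call site is what lets the optional FP8 hook apply uniformly to training forward, inference forward, and backward. The sketch below shows the generic hook-stacking pattern this relies on; it is not ColossalAI's `ColoParamOpHookManager`, just a minimal stand-in with hypothetical method names.

```python
# Generic sketch of a context-managed hook stack: a with-block installs any
# number of hooks and restores the previous set on exit, so callers can layer
# an extra hook (e.g. an FP8 one) without changing call sites.
from contextlib import contextmanager
from typing import Iterator, List, Protocol

class ParamOpHook(Protocol):
    def pre_op(self, params: list) -> None: ...
    def post_op(self, params: list) -> None: ...

class HookManager:
    _active: List[ParamOpHook] = []

    @classmethod
    @contextmanager
    def use_hooks(cls, *hooks: ParamOpHook) -> Iterator[None]:
        previous = cls._active
        cls._active = list(hooks)
        try:
            yield
        finally:
            cls._active = previous

    @classmethod
    def pre_op(cls, params: list) -> None:
        # Every active hook sees the parameters of the pending op.
        for hook in cls._active:
            hook.pre_op(params)
```

With this pattern, `with HookManager.use_hooks(zero_hook, fp8_hook): ...` activates both hooks for the enclosed ops, mirroring how the diff builds `self.hooks` once and unpacks it everywhere.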