[fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
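
The example-script side of this change is not included in the excerpt below, so the flag name is an assumption rather than something read from the diff. A minimal sketch of how a benchmark script might expose such an option:

import argparse

from colossalai.booster.plugin import GeminiPlugin

# Hypothetical benchmark wiring; "--use_fp8" is an assumed flag name,
# not one taken from this commit's example diff.
parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8", action="store_true", help="enable FP8 linear layers in the plugin")
args = parser.parse_args()

plugin = GeminiPlugin(use_fp8=args.use_fp8)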
Hongxin Liu
2024-08-09 14:09:48 +08:00
committed by GitHub
parent 4b9bec8176
commit 8241c0c054
7 changed files with 21 additions and 7 deletions


@@ -363,6 +363,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
+        use_fp8: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:
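
Together with the pre-existing fp8_communication flag, the new parameter is set when the plugin is constructed. A minimal sketch, assuming only the two keyword names visible in the hunk above (everything else is illustrative):

from colossalai.booster.plugin import GeminiPlugin

# use_fp8 (added by this commit) enables FP8 compute in eligible linear
# layers; fp8_communication (pre-existing) enables FP8 collectives.
plugin = GeminiPlugin(
    use_fp8=True,
    fp8_communication=True,
)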
@@ -397,6 +398,7 @@ class GeminiPlugin(DPPluginBase):
             max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
+            use_fp8=use_fp8,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
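
Because the flag is forwarded into self.gemini_config, it reaches the Gemini model wrapper at boost time. A hedged end-to-end sketch, assuming the usual Booster workflow under a torchrun-style launch (the model and hyperparameters are placeholders):

import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Assumed to run under torchrun / a colossalai launch; illustrative only.
colossalai.launch_from_torch()

plugin = GeminiPlugin(use_fp8=True)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(1024, 1024)
optimizer = HybridAdam(model.parameters(), lr=1e-4)
# boost() wraps the model with Gemini; with use_fp8=True the wrapper's
# config now carries the flag (via the dict entry added above).
model, optimizer, *_ = booster.boost(model, optimizer)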