[fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
Author: Hongxin Liu
Date:   2024-08-09 14:09:48 +08:00
Committer: GitHub
Parent: 4b9bec8176
Commit: 8241c0c054

7 changed files with 21 additions and 7 deletions

@@ -99,6 +99,8 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true")
     args = parser.parse_args()
     colossalai.launch_from_torch()
@@ -136,6 +138,7 @@ def main():
             enable_flash_attention=args.xformers,
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -148,6 +151,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -207,6 +211,8 @@ def main():
             dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -223,6 +229,7 @@ def main():
             initial_scale=2**8,
             precision="bf16",
             overlap_p2p=args.overlap,
+            use_fp8=args.use_fp8,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")