From 66018749f3fd79ff92e36b6fa39262f4c6355872 Mon Sep 17 00:00:00 2001 From: BurkeHulk Date: Fri, 12 Jul 2024 15:26:17 +0800 Subject: [PATCH] add fp8_communication flag in the script --- examples/language/bert/finetune.py | 2 ++ examples/language/gpt/hybridparallelism/finetune.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/examples/language/bert/finetune.py b/examples/language/bert/finetune.py index 7e8c07fdc..8a59ab683 100644 --- a/examples/language/bert/finetune.py +++ b/examples/language/bert/finetune.py @@ -190,6 +190,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "bert": @@ -232,6 +233,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs) diff --git a/examples/language/gpt/hybridparallelism/finetune.py b/examples/language/gpt/hybridparallelism/finetune.py index 777d16cb9..9b3a10160 100644 --- a/examples/language/gpt/hybridparallelism/finetune.py +++ b/examples/language/gpt/hybridparallelism/finetune.py @@ -187,6 +187,7 @@ def main(): ) parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached") parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context") + parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication") args = parser.parse_args() if args.model_type == "gpt2": @@ -225,6 +226,7 @@ def main(): zero_stage=1, precision="fp16", initial_scale=1, + fp8_communication=args.use_fp8_comm, ) booster = Booster(plugin=plugin, **booster_kwargs)