Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-08-15 14:43:13 +00:00)
add fp8_communication flag in the script
commit 66018749f3
parent e88190184a
First modified script (handles args.model_type == "bert"):

@@ -190,6 +190,7 @@ def main():
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
     parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
     args = parser.parse_args()
 
     if args.model_type == "bert":
@@ -232,6 +233,7 @@ def main():
         zero_stage=1,
         precision="fp16",
         initial_scale=1,
+        fp8_communication=args.use_fp8_comm,
     )
 
     booster = Booster(plugin=plugin, **booster_kwargs)
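For reference, the sketch below shows how the new flag is meant to flow from the command line into the booster plugin. The hunks above only show keyword arguments, not the plugin class, so HybridParallelPlugin (the plugin used by ColossalAI's hybrid-parallel finetune examples) and the tp_size/pp_size values are assumptions for illustration, not part of this diff.

# Minimal wiring sketch -- not the full finetune script.
# Assumption: the plugin constructed in the script is
# colossalai.booster.plugin.HybridParallelPlugin; the diff does not name it.
import argparse

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
args = parser.parse_args()

plugin = HybridParallelPlugin(
    tp_size=1,          # illustrative values; the real script configures these elsewhere
    pp_size=1,
    zero_stage=1,
    precision="fp16",
    initial_scale=1,
    fp8_communication=args.use_fp8_comm,  # the new flag is simply forwarded to the plugin
)
booster = Booster(plugin=plugin)

One note on the flag as written: with argparse, type=bool converts any non-empty string to True, so passing --use_fp8_comm False still yields True; the option effectively behaves as an on switch unless action="store_true" or a custom boolean parser is used.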
Second modified script (handles args.model_type == "gpt2"):

@@ -187,6 +187,7 @@ def main():
     )
     parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
     parser.add_argument("--use_lazy_init", type=bool, default=False, help="for initiating lazy init context")
+    parser.add_argument("--use_fp8_comm", type=bool, default=False, help="for using fp8 during communication")
     args = parser.parse_args()
 
     if args.model_type == "gpt2":
@@ -225,6 +226,7 @@ def main():
         zero_stage=1,
         precision="fp16",
         initial_scale=1,
+        fp8_communication=args.use_fp8_comm,
     )
 
     booster = Booster(plugin=plugin, **booster_kwargs)
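As background on what the forwarded fp8_communication flag generally controls: fp8 communication compresses tensors to an 8-bit floating-point format (with a per-tensor scale) before collectives such as all-gather or all-reduce and decompresses them afterwards, roughly halving communication volume relative to fp16. The sketch below illustrates that idea only; it is not ColossalAI's implementation, and it assumes PyTorch 2.1+ for the torch.float8_e4m3fn dtype.

# Illustrative only: a per-tensor-scaled fp8 round trip, the basic building block
# behind fp8 communication. This is NOT the code path enabled by fp8_communication.
import torch

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def fp8_pack(t: torch.Tensor):
    """Quantize a tensor to float8_e4m3fn with a per-tensor scale."""
    scale = t.abs().max().clamp(min=1e-12) / FP8_E4M3_MAX
    return (t / scale).to(torch.float8_e4m3fn), scale

def fp8_unpack(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize back to float32 using the stored scale."""
    return q.to(torch.float32) * scale

if __name__ == "__main__":
    x = torch.randn(4, 8)
    q, s = fp8_pack(x)      # in a scheme like this, q and s are what the collective sends
    x_hat = fp8_unpack(q, s)
    print((x - x_hat).abs().max())  # small quantization error

In such a scheme the quantized payload and its scale are what travel over the wire, which is where the bandwidth saving comes from.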