diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 712703b45..b71203518 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -79,7 +79,7 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False) parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False) parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False) - + parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number") args = parser.parse_args() colossalai.launch_from_torch() @@ -114,7 +114,7 @@ def main(): extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), enable_flash_attention=args.xformers, - max_prefetch=10, + max_prefetch=args.prefetch_num, enable_async_reduce=not args.disable_async_reduce, ) elif args.plugin == "gemini_auto": @@ -125,6 +125,8 @@ def main(): tp_size=args.tp, extra_dp_size=args.extra_dp, enable_fused_normalization=torch.cuda.is_available(), + max_prefetch=args.prefetch_num, + enable_async_reduce=not args.disable_async_reduce, enable_flash_attention=args.xformers, ) elif args.plugin == "fsdp":