mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
Merge branch 'main' of github.com:hpcaitech/ColossalAI into prefetch
This commit is contained in:
@@ -78,6 +78,8 @@ def main():
|
||||
parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
|
||||
parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
|
||||
parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
|
||||
parser.add_argument("--disable-async-reduce", action="store_true", help="Customize checkpoint", default=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch()
|
||||
@@ -113,6 +115,7 @@ def main():
|
||||
enable_fused_normalization=torch.cuda.is_available(),
|
||||
enable_flash_attention=args.xformers,
|
||||
max_prefetch=10,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
Reference in New Issue
Block a user