diff --git a/applications/Colossal-LLaMA/train.py b/applications/Colossal-LLaMA/train.py
index 112a1e0dc..db23275e4 100644
--- a/applications/Colossal-LLaMA/train.py
+++ b/applications/Colossal-LLaMA/train.py
@@ -65,6 +65,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -74,6 +75,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -99,6 +101,7 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
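
The patch gates ColossalAI's fused normalization kernels on CUDA availability, so CPU-only runs fall back to the unfused PyTorch norm. Below is a minimal sketch, not part of the patch, showing how the new flag reads at the call site for the "gemini" plugin; it uses only keyword arguments visible in the diff, and the GeminiPlugin import path is assumed to be the standard colossalai one.

# Minimal sketch (assumption: GeminiPlugin accepts the kwargs shown in the diff above).
import torch
from colossalai.booster.plugin import GeminiPlugin

plugin = GeminiPlugin(
    initial_scale=2**16,
    max_norm=1.0,                                   # stands in for args.grad_clip
    enable_gradient_accumulation=True,              # stands in for (args.accumulation_steps > 1)
    enable_fused_normalization=torch.cuda.is_available(),  # new flag: fused norm only when a GPU is present
    enable_flash_attention=False,                   # stands in for args.use_flash_attn
)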