add fused norm (#6038)

Tong Li 2024-08-28 17:12:51 +08:00 committed by GitHub
parent 4a68efb7da
commit 0d3a85d04f

@@ -65,6 +65,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -74,6 +75,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -99,6 +101,7 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
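
For context, a minimal sketch of how one of the changed call sites reads after this patch. It assumes the GeminiPlugin import path from colossalai.booster.plugin, and the literal values stand in for the script's argparse args; the point is that the flag is gated on torch.cuda.is_available(), since the fused normalization path relies on CUDA kernels and CPU-only runs should keep the unfused default.

    # Sketch only; kwargs mirror the diff above, values stand in for args.*
    import torch
    from colossalai.booster.plugin import GeminiPlugin

    plugin = GeminiPlugin(
        initial_scale=2**16,
        max_norm=1.0,  # stands in for args.grad_clip
        enable_gradient_accumulation=True,  # stands in for (args.accumulation_steps > 1)
        # Enable fused normalization only when a GPU is present; without
        # CUDA the model keeps the plain PyTorch normalization layers.
        enable_fused_normalization=torch.cuda.is_available(),
        enable_flash_attention=False,  # stands in for args.use_flash_attn
    )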