Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-22 13:41:43 +00:00)
add fused norm (#6038)

parent 4a68efb7da · commit 0d3a85d04f
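Across the three hunks below, the patch threads a single new keyword argument, enable_fused_normalization=torch.cuda.is_available(), into each plugin configuration inside train(), so the fused normalization kernels are requested only when a CUDA device is actually present.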
@@ -65,6 +65,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
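For context, a minimal sketch of what the patched "gemini" branch configures. It assumes the script builds a GeminiPlugin from colossalai.booster.plugin; the literal values stand in for the script's args.* fields, which this diff does not show.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Placeholder values stand in for args.*; only the
# enable_fused_normalization line is new in this commit.
plugin = GeminiPlugin(
    initial_scale=2**16,
    max_norm=1.0,                       # args.grad_clip
    enable_gradient_accumulation=True,  # (args.accumulation_steps > 1)
    enable_fused_normalization=torch.cuda.is_available(),  # the new flag
    enable_flash_attention=False,       # args.use_flash_attn
)
booster = Booster(plugin=plugin)

Gating the flag on torch.cuda.is_available() keeps the script usable on CPU-only machines, where the fused CUDA kernels cannot be loaded.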
@@ -74,6 +75,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -99,6 +101,7 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
+            enable_fused_normalization=torch.cuda.is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
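Likewise, a hedged sketch of the configuration the third hunk patches, assuming it constructs a HybridParallelPlugin; the class name, tp_size/pp_size, and every value below are assumptions, since they fall outside the hunk shown above.

import torch
from colossalai.booster.plugin import HybridParallelPlugin

# Must run under a distributed launcher (e.g. colossalai run or torchrun),
# since the plugin builds its process groups at construction time.
# Placeholder values stand in for args.*.
plugin = HybridParallelPlugin(
    tp_size=2,                        # assumed; not shown in the hunk
    pp_size=1,                        # assumed; not shown in the hunk
    sequence_parallelism_mode=None,   # args.sp_mode
    zero_stage=1,                     # args.zero_stage
    enable_flash_attention=False,     # args.use_flash_attn
    enable_fused_normalization=torch.cuda.is_available(),  # the new flag
    enable_sequence_parallelism=False,  # args.enable_sequence_parallelism
    cpu_offload=False,                # args.zero_stage >= 1 and args.zero_cpu_offload
    parallel_output=False,
)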