Add benchmarks for SFT, DPO, SimPO, and ORPO; add benchmarking results; support LoRA with gradient checkpointing

YeAnbang
2024-07-10 10:17:08 +00:00
parent 16f3451fe2
commit d888c3787c
13 changed files with 1175 additions and 26 deletions


@@ -122,13 +122,11 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
-    if args.grad_checkpoint and args.lora_rank == 0:
-        # lora layers are not supported by gradient checkpointing
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
 
     # configure tokenizer
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
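
Background on why the new comment hedges: under LoRA the base model weights are frozen, so for some architectures the inputs to checkpointed blocks do not require grad, and reentrant gradient checkpointing can then fail or yield no gradients for the adapter weights. Below is a minimal sketch of a common workaround in the Hugging Face transformers/peft stack; it is not taken from this commit, the checkpoint name, rank, and target modules are placeholders, and the gradient_checkpointing_kwargs argument requires a reasonably recent transformers release.

    # Sketch: combining LoRA (peft) with gradient checkpointing on a HF causal LM.
    # The model name, rank, and target modules are illustrative placeholders.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

    # Base weights are frozen under LoRA, so make checkpointed block inputs
    # require grad and prefer non-reentrant checkpointing.
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_modules=["q_proj", "v_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.train()

With this setup, gradient checkpointing and LoRA can usually coexist, which is consistent with the change above that no longer disables checkpointing whenever args.lora_rank > 0.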