Add benchmarks for SFT, DPO, SimPO, and ORPO; add benchmarking results; support LoRA with gradient checkpointing

YeAnbang
2024-07-10 10:17:08 +00:00
parent 16f3451fe2
commit d888c3787c
13 changed files with 1175 additions and 26 deletions


@@ -128,16 +128,14 @@ def train(args):
         disable_dropout(ref_model)
     else:
         ref_model = None
-    print("ref_model is None", args.disable_reference_model, ref_model is None)
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
-    if args.grad_checkpoint and args.lora_rank == 0:
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
 
     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
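
Side note (not part of this commit): the new comment warns that LoRA and gradient checkpointing do not always mix. A common cause is that with all base weights frozen, reentrant activation checkpointing sees no output that requires a gradient. A minimal sketch of the usual workaround with the Hugging Face Transformers API on recent versions, using a placeholder checkpoint name rather than the script's args.pretrain:

    from transformers import AutoModelForCausalLM

    # placeholder model; the training script loads its model from args.pretrain
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    # ... wrap with LoRA adapters here, e.g. convert_to_lora_module(model, lora_rank, ...)

    # non-reentrant checkpointing avoids the "none of the outputs requires grad"
    # error that reentrant checkpointing can raise when the base weights are frozen
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    # alternatively, force the embedding outputs to require grad so each
    # checkpointed segment keeps a grad-requiring tensor in the graph
    model.enable_input_require_grads()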