add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint

This commit is contained in:
YeAnbang
2024-07-10 10:17:08 +00:00
parent 16f3451fe2
commit d888c3787c
13 changed files with 1175 additions and 26 deletions

View File

@@ -128,16 +128,14 @@ def train(args):
disable_dropout(ref_model)
else:
ref_model = None
print("ref_model is None", args.disable_reference_model, ref_model is None)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)

View File

@@ -118,12 +118,11 @@ def train(args):
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
if args.grad_checkpoint and args.lora_rank == 0:
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)

View File

@@ -122,13 +122,11 @@ def train(args):
# LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
# )
if args.grad_checkpoint and args.lora_rank == 0:
# lora layers are not supported by gradient checkpointing
if args.grad_checkpoint:
# Note, for some models, lora may not be compatible with gradient checkpointing
model.gradient_checkpointing_enable()
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
elif args.lora_rank > 0:
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
# configure tokenizer
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True