mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-01 09:07:51 +00:00
Add benchmarks for SFT, DPO, SimPO, and ORPO; add benchmarking results; support LoRA with gradient checkpointing.
@@ -128,16 +128,14 @@ def train(args):
         disable_dropout(ref_model)
     else:
         ref_model = None
     print("ref_model is None", args.disable_reference_model, ref_model is None)
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

-    if args.grad_checkpoint and args.lora_rank == 0:
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")

     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
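Taken together, this hunk replaces the old two-branch logic (enable checkpointing only when LoRA is off, otherwise print a warning) with a single unconditional check. The self-contained toy below paraphrases that behavioral change with plain booleans standing in for the real args/model/coordinator objects; the function names are hypothetical stand-ins, not code from the script.

# Toy comparison of the old and new branching; hypothetical stand-ins only.
def old_behavior(grad_checkpoint: bool, lora_rank: int) -> bool:
    """Old logic: checkpointing was silently skipped whenever LoRA was active."""
    return grad_checkpoint and lora_rank == 0

def new_behavior(grad_checkpoint: bool, lora_rank: int) -> bool:
    """New logic: checkpointing is enabled whenever requested, LoRA or not."""
    return grad_checkpoint

# The only configuration whose outcome changes: --grad_checkpoint together with --lora_rank > 0.
assert old_behavior(True, lora_rank=8) is False
assert new_behavior(True, lora_rank=8) is True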
@@ -118,12 +118,11 @@ def train(args):
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

-    if args.grad_checkpoint and args.lora_rank == 0:
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")

     # configure tokenizer
     tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, use_fast=False, trust_remote_code=True)
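Both this hunk and the next configure the tokenizer with a fallback from args.tokenizer_dir to args.pretrain, but they spell the fallback differently: an explicit `is not None` check here, a bare `or` in the next hunk. The standalone sketch below, with hypothetical placeholder paths rather than values from the scripts, shows the one practical difference: the `or` form also falls back when the value is an empty string.

# Standalone illustration of the two fallback idioms; the paths are placeholders.
tokenizer_dir = ""                      # e.g. the flag was passed but left empty
pretrain = "path/to/pretrained/model"   # hypothetical pretrain directory

explicit = tokenizer_dir if tokenizer_dir is not None else pretrain
short = tokenizer_dir or pretrain

print(repr(explicit))  # '' -- the empty string is kept, which from_pretrained would reject
print(repr(short))     # 'path/to/pretrained/model' -- falsy values fall through to pretrain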
@@ -122,13 +122,11 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )

-    if args.grad_checkpoint and args.lora_rank == 0:
-        # lora layers are not supported by gradient checkpointing
+    if args.grad_checkpoint:
+        # Note, for some models, lora may not be compatible with gradient checkpointing
         model.gradient_checkpointing_enable()
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
-    elif args.lora_rank > 0:
-        coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")

     # configure tokenizer
     tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_dir or args.pretrain, use_fast=False, trust_remote_code=True
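The new comment ("for some models, lora may not be compatible with gradient checkpointing") usually points at a known interaction: with the base weights frozen by LoRA, the inputs to the checkpointed blocks may not require grad, so reentrant checkpointing has nothing to backpropagate through and the adapter weights receive no gradient. The sketch below shows common workarounds on a Hugging Face-style model; it uses standard transformers/PyTorch APIs as an illustration, is not code from this commit, and the checkpoint name is a placeholder.

# Common workarounds when combining LoRA (frozen base weights) with gradient
# checkpointing on a transformers model. Illustration only; not part of this commit.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder checkpoint

# Option 1: force the embedding output to require grad so the backward pass can
# reach the (trainable) LoRA parameters through the checkpointed blocks.
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

# Option 2 (recent transformers releases): use non-reentrant checkpointing, which
# does not depend on the inputs requiring grad.
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})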