add SimPO

This commit is contained in:
YeAnbang
2024-06-24 02:12:20 +00:00
parent 84eab13078
commit 82aecd6374
14 changed files with 128 additions and 70 deletions

View File

@@ -116,7 +116,7 @@ def train(args):
else:
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
disable_dropout(model)
if args.enable_reference_model:
if not args.disable_reference_model:
if args.use_flash_attn:
ref_model = AutoModelForCausalLM.from_pretrained(
args.pretrain,
@@ -128,7 +128,7 @@ def train(args):
disable_dropout(ref_model)
else:
ref_model = None
print("ref_model is None", args.disable_reference_model, ref_model is None)
if args.lora_rank > 0:
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
@@ -255,6 +255,9 @@ def train(args):
save_interval=args.save_interval,
save_dir=args.save_dir,
coordinator=coordinator,
beta=args.beta,
gamma=args.gamma,
length_normalization=args.length_normalization,
)
trainer.fit(
@@ -296,6 +299,9 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
parser.add_argument("--length_normalization", default=False, action="store_true")
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
@@ -312,7 +318,12 @@ if __name__ == "__main__":
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--enable_reference_model", type=bool, default=True)
parser.add_argument(
"--disable_reference_model",
action="store_true",
default=False,
help="Disable the reference model (enabled by default)",
)
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
parser.add_argument(