mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 12:01:39 +00:00
add SimPO
This commit is contained in:
@@ -116,7 +116,7 @@ def train(args):
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||
disable_dropout(model)
|
||||
if args.enable_reference_model:
|
||||
if not args.disable_reference_model:
|
||||
if args.use_flash_attn:
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||
args.pretrain,
|
||||
@@ -128,7 +128,7 @@ def train(args):
|
||||
disable_dropout(ref_model)
|
||||
else:
|
||||
ref_model = None
|
||||
|
||||
print("ref_model is None", args.disable_reference_model, ref_model is None)
|
||||
if args.lora_rank > 0:
|
||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
||||
|
||||
@@ -255,6 +255,9 @@ def train(args):
|
||||
save_interval=args.save_interval,
|
||||
save_dir=args.save_dir,
|
||||
coordinator=coordinator,
|
||||
beta=args.beta,
|
||||
gamma=args.gamma,
|
||||
length_normalization=args.length_normalization,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
@@ -296,6 +299,9 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
|
||||
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
|
||||
parser.add_argument("--length_normalization", default=False, action="store_true")
|
||||
parser.add_argument("--enable_sequence_parallelism", default=False, action="store_true")
|
||||
parser.add_argument("--zero_stage", type=int, default=0, help="Zero stage", choices=[0, 1, 2])
|
||||
parser.add_argument("--zero_cpu_offload", default=False, action="store_true")
|
||||
@@ -312,7 +318,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--max_length", type=int, default=2048, help="Model max length")
|
||||
parser.add_argument("--max_epochs", type=int, default=3)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--enable_reference_model", type=bool, default=True)
|
||||
parser.add_argument(
|
||||
"--disable_reference_model",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable the reference model (enabled by default)",
|
||||
)
|
||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument(
|
||||
|
Reference in New Issue
Block a user