This commit is contained in:
YeAnbang
2024-06-27 07:20:28 +00:00
parent f3de5a025c
commit c8d1b4a968
12 changed files with 783 additions and 13 deletions

View File

@@ -299,6 +299,7 @@ if __name__ == "__main__":
parser.add_argument("--tp", type=int, default=1)
parser.add_argument("--pp", type=int, default=1)
parser.add_argument("--sp", type=int, default=1)
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="do_loss or simpo_loss")
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
parser.add_argument("--length_normalization", default=False, action="store_true")
@@ -341,6 +342,12 @@ if __name__ == "__main__":
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true")
args = parser.parse_args()
# fool proof hyperparameter setup
if args.loss_type == "simpo_loss":
args.length_normalization = True
args.gamma = args.gamma if args.gamma > 0 else 1.4
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
with open(args.config_file, "w") as f:
json.dump(args.__dict__, f, indent=4)