mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-07 20:10:17 +00:00
add orpo
This commit is contained in:
@@ -299,6 +299,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--tp", type=int, default=1)
|
||||
parser.add_argument("--pp", type=int, default=1)
|
||||
parser.add_argument("--sp", type=int, default=1)
|
||||
parser.add_argument("--loss_type", type=str, default="dpo_loss", help="do_loss or simpo_loss")
|
||||
parser.add_argument("--beta", type=float, default=0.1, help="beta in DPO loss")
|
||||
parser.add_argument("--gamma", type=float, default=0.0, help="gamma in SimPO loss")
|
||||
parser.add_argument("--length_normalization", default=False, action="store_true")
|
||||
@@ -341,6 +342,12 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
# fool proof hyperparameter setup
|
||||
if args.loss_type == "simpo_loss":
|
||||
args.length_normalization = True
|
||||
args.gamma = args.gamma if args.gamma > 0 else 1.4
|
||||
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
with open(args.config_file, "w") as f:
|
||||
json.dump(args.__dict__, f, indent=4)
|
||||
|
Reference in New Issue
Block a user