mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-08 04:24:47 +00:00
Support overall loss, update KTO logging
This commit is contained in:
@@ -411,6 +411,7 @@ def train(args):
|
||||
use_cache=True,
|
||||
do_sample=True,
|
||||
temperature=0.7,
|
||||
apply_loss_mask=not args.disable_loss_mask,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
save_dir=args.save_path,
|
||||
save_interval=args.save_interval,
|
||||
@@ -498,6 +499,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||
parser.add_argument("--ptx_coef", type=float, default=0.0)
|
||||
parser.add_argument("--disable_loss_mask", default=False, action="store_true")
|
||||
parser.add_argument("--max_length", type=int, default=2048)
|
||||
parser.add_argument("--max_seq_len", type=int, default=256)
|
||||
parser.add_argument("--log_dir", default="logs", type=str)
|
||||
|
Reference in New Issue
Block a user