fix logging rollouts

This commit is contained in:
YeAnbang
2025-05-17 21:12:58 +08:00
parent 03b41d6fb5
commit 107470a360
5 changed files with 56 additions and 24 deletions

View File

@@ -118,6 +118,9 @@ if __name__ == "__main__":
parser.add_argument(
"-esd", "--eval-save-dir", type=str, default="./eval", help="Directory for saving evaluation results."
)
parser.add_argument(
"-rsd", "--rollout-save-dir", type=str, default="./rollouts", help="Directory for saving rollout loggings."
)
args = parser.parse_args()
if args.train_minibatch_size is None:
@@ -269,4 +272,6 @@ if __name__ == "__main__":
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,
log_rollout_interval=20,
rollout_log_file=os.path.join(args.rollout_save_dir, args.project.replace(" ", "_") + ".jsonl"),
)