fix default eval setting (#6321)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li 2025-05-22 11:52:41 +08:00 committed by YeAnbang
parent 2a39d3afd9
commit 382307a62c
2 changed files with 19 additions and 6 deletions

View File

@ -151,7 +151,7 @@ class BaseProducer:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags
else:
raise ValueError("eval_dataset_config is not defined")
print("No eval dataset provided, skip eval")
self.device = get_current_device()
# init backend

View File

@ -10,7 +10,16 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
parser.add_argument("-p", "--project", type=str, default="GRPO-V3", help="Project name.")
parser.add_argument(
"-ed",
"--eval-dataset",
type=str,
default=None,
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
The key is the task name, and the value is the path to the jsonl file",
)
parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
parser.add_argument("-e", "--num-episodes", type=int, default=1, help="Number of episodes to train.")
# Distributed training parameters
@ -301,10 +310,14 @@ if __name__ == "__main__":
project_name=args.project,
save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config={
eval_dataset_config=(
{
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
for k, v in json.loads(args.eval_dataset).items()
},
}
if args.eval_dataset
else None
),
eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config,