Merge branch 'grpo-latest' of https://github.com/hpcaitech/ColossalAI into grpo-latest-dev

This commit is contained in:
YeAnbang 2025-05-28 17:34:52 +08:00
commit 58f8c9bb43
2 changed files with 10 additions and 6 deletions

View File

@@ -151,7 +151,7 @@ class BaseProducer:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}") raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
self.response_format_tags = response_format_tags self.response_format_tags = response_format_tags
else: else:
raise ValueError("eval_dataset_config is not defined") print("No eval dataset provided, skip eval")
self.device = get_current_device() self.device = get_current_device()
# init backend # init backend

View File

@@ -14,7 +14,7 @@ if __name__ == "__main__":
"-ed", "-ed",
"--eval-dataset", "--eval-dataset",
type=str, type=str,
default='{"eval task name":"data_eval.jsonl"}', default=None,
help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \ help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \ For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
The key is the task name, and the value is the path to the jsonl file", The key is the task name, and the value is the path to the jsonl file",
@@ -310,10 +310,14 @@ if __name__ == "__main__":
project_name=args.project, project_name=args.project,
save_interval=args.save_interval, save_interval=args.save_interval,
save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")), save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
eval_dataset_config={ eval_dataset_config=(
k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt} {
for k, v in json.loads(args.eval_dataset).items() k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
}, for k, v in json.loads(args.eval_dataset).items()
}
if args.eval_dataset
else None
),
eval_interval=args.eval_interval, eval_interval=args.eval_interval,
eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")), eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
eval_generation_config=eval_generation_config, eval_generation_config=eval_generation_config,