diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index ef81bcbdd..c9e8d2ab2 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -1,4 +1,5 @@
 import copy
+import os
 import uuid
 from typing import Any, Dict, Optional
 
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 75879278c..78964afa1 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -149,7 +149,7 @@ class BaseProducer:
             else:
                 raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
         else:
-            raise ValueError("eval_dataset_config is not defined")
+            print("No eval dataset provided, skip eval")
         self.device = get_current_device()
 
         # init backend
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index bfa0ab7d0..5aec7f5a6 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -14,7 +14,7 @@ if __name__ == "__main__":
         "-ed",
         "--eval-dataset",
         type=str,
-        default='{"eval task name":"data_eval.jsonl"}',
+        default=None,
         help="Evaluation dataset for each task, please use json format to specify the dataset for each task. \
             For example: {'task1':'data_eval_task1.jsonl', 'task2':'data_eval_task2.jsonl'}, the jsonl file should be in the same format as the training dataset. \
             The key is the task name, and the value is the path to the jsonl file",
@@ -265,10 +265,14 @@ if __name__ == "__main__":
         project_name=args.project,
         save_interval=args.save_interval,
         save_dir=os.path.join(args.save_dir, args.project.replace(" ", "_")),
-        eval_dataset_config={
-            k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
-            for k, v in json.loads(args.eval_dataset).items()
-        },
+        eval_dataset_config=(
+            {
+                k: {"path": v, "max_length": args.max_prompt_tokens, "system_prompt": args.system_prompt}
+                for k, v in json.loads(args.eval_dataset).items()
+            }
+            if args.eval_dataset
+            else None
+        ),
         eval_interval=args.eval_interval,
         eval_save_dir=os.path.join(args.eval_save_dir, args.project.replace(" ", "_")),
         eval_generation_config=eval_generation_config,
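
With `--eval-dataset` now defaulting to `None`, the `eval_dataset_config` keyword is built conditionally and the producer skips evaluation instead of raising. A minimal standalone sketch of that guard (not part of the patch; `build_eval_dataset_config` is a hypothetical helper, with `eval_dataset`, `max_prompt_tokens`, and `system_prompt` standing in for the parsed CLI arguments):

import json
from typing import Any, Dict, Optional

def build_eval_dataset_config(
    eval_dataset: Optional[str], max_prompt_tokens: int, system_prompt: Optional[str]
) -> Optional[Dict[str, Dict[str, Any]]]:
    # No --eval-dataset supplied: return None so BaseProducer skips evaluation
    # (it now prints "No eval dataset provided, skip eval" instead of raising).
    if not eval_dataset:
        return None
    # Otherwise map each task name to its per-task config, mirroring the
    # dict comprehension in rl_example.py.
    return {
        k: {"path": v, "max_length": max_prompt_tokens, "system_prompt": system_prompt}
        for k, v in json.loads(eval_dataset).items()
    }

# Example: one eval task, in the JSON shape the --eval-dataset flag expects.
print(build_eval_dataset_config('{"task1": "data_eval_task1.jsonl"}', 512, None))
print(build_eval_dataset_config(None, 512, None))  # -> None, eval is skipped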