diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 369023977..152689e06 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,16 +120,7 @@ class GRPOConsumer(BaseConsumer):
                 "either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
             )
         # Initialize verifiable reward.
-        response_format_tags = (
-            {
-                "think_start": {"text": "<think>", "num_occur": 1},
-                "think_end": {"text": "</think>", "num_occur": 1},
-                "answer_start": {"text": "<answer>", "num_occur": 1},
-                "answer_end": {"text": "</answer>", "num_occur": 1},
-            }
-            if grpo_config.get("reward_fn_type") == "think_answer_tags"
-            else None
-        )
+        response_format_tags = grpo_config.get("response_format_tags", None)
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index c9e8d2ab2..764325cdd 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -100,6 +100,7 @@ def launch_distributed(
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
             evaluation_function_type=grpo_config["reward_fn_type"],
+            response_format_tags=grpo_config["response_format_tags"],
             eval_save_dir=eval_save_dir,
             eval_generation_config=eval_generation_config,
             project_name=project_name,
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 75879278c..916ae3326 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -46,6 +46,7 @@ class BaseProducer:
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         project_name: str = None,
         run_name: str = None,
@@ -148,6 +149,7 @@ class BaseProducer:
                 self.evaluation_function = boxed_math_reward_fn
             else:
                 raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+            self.response_format_tags = response_format_tags
         else:
             raise ValueError("eval_dataset_config is not defined")
         self.device = get_current_device()
@@ -217,6 +219,7 @@ class BaseProducer:
                             eval_outputs["response_idx"][m][n],
                             tokenizer=self.tokenizer,
                             eval_mode=True,
+                            tags=self.response_format_tags,
                         )
                         for m in range(eval_outputs["input_ids"].size(0))
                         for n in range(eval_outputs["input_ids"].size(1))
@@ -324,6 +327,7 @@ class SimpleProducer(BaseProducer):
         eval_dataset_config=None,
         eval_interval=-1,  # disable evaluation
         evaluation_function_type="think_answer_tags",
+        response_format_tags=None,
         eval_save_dir: str = "./eval",
         eval_generation_config={},
         project_name: str = None,
@@ -349,6 +353,7 @@ class SimpleProducer(BaseProducer):
             eval_dataset_config=eval_dataset_config,
             eval_interval=eval_interval,
             evaluation_function_type=evaluation_function_type,
+            response_format_tags=response_format_tags,
             eval_save_dir=eval_save_dir,
             project_name=project_name,
             run_name=run_name,
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 53f166d66..e1025442c 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -231,6 +231,16 @@ if __name__ == "__main__":
             "reward_fn_type": args.reward_type,
             "max_length": args.max_new_tokens + args.max_prompt_tokens,
             "max_new_tokens": args.max_new_tokens,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     elif args.algo == "DAPO":
         # DAPO variant settings
@@ -250,6 +260,16 @@
             "cache_length": min(1024, int(args.max_new_tokens / 4)),
             "filter_truncated_response": True,
             "reward_fn_type": args.reward_type,
+            "response_format_tags": (
+                {
+                    "think_start": {"text": "<think>", "num_occur": 1},
+                    "think_end": {"text": "</think>", "num_occur": 1},
+                    "answer_start": {"text": "<answer>", "num_occur": 1},
+                    "answer_end": {"text": "</answer>", "num_occur": 1},
+                }
+                if args.reward_type == "think_answer_tags"
+                else None
+            ),
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
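
For context, the diff moves the hard-coded think/answer tag table out of GRPOConsumer and into the example script's grpo_config, then threads it through launch_distributed and the producers as "response_format_tags" so the evaluation call can receive it via tags=self.response_format_tags. Below is a minimal standalone sketch, not the ColossalChat implementation, of how a reward function could use such a tag table; the helper name check_response_format and its signature are hypothetical and only illustrate the shape of the config.

    # Hypothetical sketch: validate that a response contains each configured
    # tag exactly "num_occur" times, mirroring the response_format_tags layout
    # used in the diff above. Not taken from the ColossalChat codebase.
    from typing import Optional


    def check_response_format(response: str, tags: Optional[dict]) -> bool:
        """Return True when every tag's occurrence count matches its spec."""
        if tags is None:
            # No format constraints configured (e.g. reward_fn_type is not
            # "think_answer_tags"), so the check trivially passes.
            return True
        return all(
            response.count(spec["text"]) == spec["num_occur"] for spec in tags.values()
        )


    if __name__ == "__main__":
        tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        good = "<think>2 + 2 = 4</think><answer>4</answer>"
        bad = "The answer is 4."
        print(check_response_format(good, tags))  # True
        print(check_response_format(bad, tags))   # False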