diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 369023977..152689e06 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -120,16 +120,7 @@ class GRPOConsumer(BaseConsumer):
"either max_tokens (vllm) or max_new_tokens (transformers) must be set in generate_config."
)
# Initialize verifiable reward.
- response_format_tags = (
- {
- "think_start": {"text": "", "num_occur": 1},
- "think_end": {"text": "", "num_occur": 1},
- "answer_start": {"text": "", "num_occur": 1},
- "answer_end": {"text": "", "num_occur": 1},
- }
- if grpo_config.get("reward_fn_type") == "think_answer_tags"
- else None
- )
+ response_format_tags = grpo_config.get("response_format_tags", None)
reward_model_kwargs = {
k: v
for k, v in grpo_config.items()
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index c9e8d2ab2..764325cdd 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -100,6 +100,7 @@ def launch_distributed(
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=grpo_config["reward_fn_type"],
+ response_format_tags=grpo_config["response_format_tags"],
eval_save_dir=eval_save_dir,
eval_generation_config=eval_generation_config,
project_name=project_name,
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 75879278c..916ae3326 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -46,6 +46,7 @@ class BaseProducer:
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
+ response_format_tags=None,
eval_save_dir: str = "./eval",
project_name: str = None,
run_name: str = None,
@@ -148,6 +149,7 @@ class BaseProducer:
self.evaluation_function = boxed_math_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {evaluation_function_type}")
+ self.response_format_tags = response_format_tags
else:
raise ValueError("eval_dataset_config is not defined")
self.device = get_current_device()
@@ -217,6 +219,7 @@ class BaseProducer:
eval_outputs["response_idx"][m][n],
tokenizer=self.tokenizer,
eval_mode=True,
+ tags=self.response_format_tags,
)
for m in range(eval_outputs["input_ids"].size(0))
for n in range(eval_outputs["input_ids"].size(1))
@@ -324,6 +327,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=None,
eval_interval=-1, # disable evaluation
evaluation_function_type="think_answer_tags",
+ response_format_tags=None,
eval_save_dir: str = "./eval",
eval_generation_config={},
project_name: str = None,
@@ -349,6 +353,7 @@ class SimpleProducer(BaseProducer):
eval_dataset_config=eval_dataset_config,
eval_interval=eval_interval,
evaluation_function_type=evaluation_function_type,
+ response_format_tags=response_format_tags,
eval_save_dir=eval_save_dir,
project_name=project_name,
run_name=run_name,
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 53f166d66..e1025442c 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -231,6 +231,16 @@ if __name__ == "__main__":
"reward_fn_type": args.reward_type,
"max_length": args.max_new_tokens + args.max_prompt_tokens,
"max_new_tokens": args.max_new_tokens,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
}
elif args.algo == "DAPO":
# DAPO variant settings
@@ -250,6 +260,16 @@ if __name__ == "__main__":
"cache_length": min(1024, int(args.max_new_tokens / 4)),
"filter_truncated_response": True,
"reward_fn_type": args.reward_type,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
}
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")