diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py index 82e4b6ac1..854c2fcc2 100644 --- a/applications/ColossalChat/coati/distributed/producer.py +++ b/applications/ColossalChat/coati/distributed/producer.py @@ -128,6 +128,14 @@ class BaseProducer: drop_last=True, collate_fn=collate_fn_grpo, ) + if grpo_config["reward_fn_type"] == "think_answer_tags": + self.evaluation_function = math_reward_fn + elif grpo_config["reward_fn_type"] == "boxed": + self.evaluation_function = boxed_math_reward_fn + elif grpo_config["reward_fn_type"] == "code": + self.evaluation_function = code_reward_fn + else: + raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}") self.eval_dataset_config = eval_dataset_config if self.eval_dataset_config is not None: @@ -151,14 +159,6 @@ class BaseProducer: ), collate_fn=collate_fn_grpo, ) - if grpo_config["reward_fn_type"] == "think_answer_tags": - self.evaluation_function = math_reward_fn - elif grpo_config["reward_fn_type"] == "boxed": - self.evaluation_function = boxed_math_reward_fn - elif grpo_config["reward_fn_type"] == "code": - self.evaluation_function = code_reward_fn - else: - raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}") else: print("No eval dataset provided, skip eval") self.device = get_current_device()