move out evaluation func

This commit is contained in:
Tong Li 2025-06-10 05:09:22 +00:00
parent c308b42f38
commit 923b23d5fe

View File

@ -128,6 +128,14 @@ class BaseProducer:
drop_last=True, drop_last=True,
collate_fn=collate_fn_grpo, collate_fn=collate_fn_grpo,
) )
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
self.eval_dataset_config = eval_dataset_config self.eval_dataset_config = eval_dataset_config
if self.eval_dataset_config is not None: if self.eval_dataset_config is not None:
@ -151,14 +159,6 @@ class BaseProducer:
), ),
collate_fn=collate_fn_grpo, collate_fn=collate_fn_grpo,
) )
if grpo_config["reward_fn_type"] == "think_answer_tags":
self.evaluation_function = math_reward_fn
elif grpo_config["reward_fn_type"] == "boxed":
self.evaluation_function = boxed_math_reward_fn
elif grpo_config["reward_fn_type"] == "code":
self.evaluation_function = code_reward_fn
else:
raise ValueError(f"Unknown evaluation function type {grpo_config['reward_fn_type']}")
else: else:
print("No eval dataset provided, skip eval") print("No eval dataset provided, skip eval")
self.device = get_current_device() self.device = get_current_device()