From 9642b75581cb3e2cc050faba1e1a26390fbf0e3f Mon Sep 17 00:00:00 2001 From: YeAnbang Date: Wed, 30 Apr 2025 22:59:54 +0800 Subject: [PATCH] upgrade reward math verification --- .../ColossalChat/coati/distributed/reward/reward_fn.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py index 14d340dc4..6844d700a 100644 --- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py +++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py @@ -1,4 +1,5 @@ import torch +from math_verify import parse, verify from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure @@ -35,11 +36,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): format_acc += 1 # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid - if ( - format_valid - and final_answer is not None - and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower() - ): + if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())): ans_acc += 1 reward += acc_score @@ -91,7 +88,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): reward += format_score # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid - if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower(): + if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())): ans_acc += 1 reward += acc_score