From d06042b434396d5f2001e9540f308f932bd7477e Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 1 May 2025 11:28:05 +0800
Subject: [PATCH] rewrite reward fn

---
 .../coati/distributed/consumer.py             |  2 +-
 .../coati/distributed/reward/reward_fn.py     | 74 +++++++++++++++++--
 2 files changed, 68 insertions(+), 8 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index 31bd73e88..a3a0948bf 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -127,7 +127,7 @@ class BaseConsumer:
                         eval_statistics = {
                             k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                         }
-                eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                 if dist.get_rank() == 0:
                     if hasattr(self, "wandb_run"):
                         self.wandb_run.log(eval_statistics, step=eval_global_step)
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 6844d700a..467a9b414 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify whether the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # plain-text answer that cannot be parsed as math, but matches the ground truth
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid latex math representation, but the answer is correct; receive half of the score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
 
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode: