diff --git a/applications/ColossalChat/coati/distributed/consumer.py b/applications/ColossalChat/coati/distributed/consumer.py
index a3a0948bf..31bd73e88 100644
--- a/applications/ColossalChat/coati/distributed/consumer.py
+++ b/applications/ColossalChat/coati/distributed/consumer.py
@@ -127,7 +127,7 @@ class BaseConsumer:
                         eval_statistics = {
                             k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                         }
-                eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                 if dist.get_rank() == 0:
                     if hasattr(self, "wandb_run"):
                         self.wandb_run.log(eval_statistics, step=eval_global_step)
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index 467a9b414..6844d700a 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -1,70 +1,8 @@
 import torch
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
+from math_verify import parse, verify

 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

-CANNOT_PARSE_GT_ANSWER = -1
-CANNOT_PARSE_PREDICTION = -2
-SUCCESS = 1
-MATCHING_FAIL = 0
-
-
-def verify_math_representation(completion, gt_answer):
-    """
-    Verify if the completion is a valid math representation of the gt_answer.
-    """
-    target = (
-        ExprExtractionConfig(),
-        LatexExtractionConfig(
-            normalization_config=NormalizationConfig(
-                nits=False,
-                malformed_operators=False,
-                basic_latex=True,
-                boxed="all",
-                units=True,
-            ),
-            boxed_match_priority=0,
-        ),
-    )
-    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
-        raise ValueError("gt_answer should be a string, please verify your training data.")
-    if not isinstance(completion, str) or len(completion) == 0:
-        return MATCHING_FAIL
-    try:
-        parsed_gt_answer = parse(gt_answer, extraction_config=target)
-        if len(parsed_gt_answer) == 0:
-            return CANNOT_PARSE_GT_ANSWER
-        parsed_completion = parse(completion, extraction_config=target)
-        if len(parsed_completion) == 0:
-            return CANNOT_PARSE_PREDICTION
-        if verify(parsed_gt_answer, parsed_completion):
-            return SUCCESS
-        else:
-            return MATCHING_FAIL
-    except Exception:
-        return MATCHING_FAIL
-
-
-def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
-    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
-    if math_verify_result == SUCCESS:
-        ans_acc += 1
-        reward += acc_score
-    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
-        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
-            ",", ""
-        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
-            ans_acc += 1
-            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
-                # plain text answer cannot be parsed, but is correct
-                reward += acc_score
-            else:
-                reward += (
-                    acc_score / 2
-                )  # not a valid latex math representation, but the answer is correct, receive half of the score
-    return reward, ans_acc
-

 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -98,8 +36,9 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
+        ans_acc += 1
+        reward += acc_score

     reward = reward + length_reward

@@ -149,8 +88,9 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             reward += format_score

     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
+        ans_acc += 1
+        reward += acc_score

     reward = reward + length_reward
     if not eval_mode:
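For context, the simplified check that both reward functions now share can be exercised on its own. Below is a minimal sketch, assuming only the math_verify package the new import already depends on; the gt_answer and final_answer values are made-up illustrations, not taken from the training data:

    # Minimal sketch of the new answer check (illustrative values only).
    from math_verify import parse, verify

    gt_answer = "\\boxed{\\frac{1}{2}}"  # hypothetical ground-truth string
    final_answer = "1/2"                 # hypothetical extracted model answer

    # parse() extracts a sympy-backed expression from the string; verify()
    # compares the parsed gold answer against the parsed prediction.
    if verify(parse(gt_answer.strip()), parse(final_answer.strip())):
        print("answer accepted")  # the reward fn then does: ans_acc += 1; reward += acc_score

The trade-off in this change: the removed verify_model_answer fallback granted full or half credit (acc_score / 2) for answers math_verify could not parse but that matched the ground truth as plain text; with the inlined check, an answer that parse/verify cannot confirm simply earns no accuracy reward.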