rewrite reward fn

YeAnbang 2025-05-01 11:28:05 +08:00
parent a6085ff676
commit d06042b434
2 changed files with 68 additions and 8 deletions

View File

@@ -127,7 +127,7 @@ class BaseConsumer:
                     eval_statistics = {
                         k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                     }
-            eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+            eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
             if dist.get_rank() == 0:
                 if hasattr(self, "wandb_run"):
                     self.wandb_run.log(eval_statistics, step=eval_global_step)
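
One note on the consumer change: wandb groups metric keys into panel sections by the prefix before the "/", so renaming the keys places all evaluation metrics under one shared "eval" section. A minimal sketch with hypothetical values, not part of the commit:

import torch

# Hypothetical accumulated (numerator, denominator) pairs, as built up in the eval loop.
eval_statistics = {"accuracy": (torch.tensor(3.0), torch.tensor(4.0))}

# The commit's change: namespace every key under "eval/" before logging.
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'eval/accuracy': 0.75}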

View File

@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
+
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # plain text answer cannot be parsed, but is correct
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid latex math representation, but the answer is correct, receive half of the score
+    return reward, ans_acc
 
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode:
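
Taken together, both reward functions replace the single verify(parse(...), parse(...)) check with a three-way outcome from verify_math_representation, plus a normalized string-comparison fallback when parsing fails. A minimal sketch of how the return codes map onto reward updates, with a hypothetical acc_score (the real value comes from the reward function's kwargs):

# Return codes copied from the diff above; acc_score is a hypothetical value.
CANNOT_PARSE_GT_ANSWER = -1
CANNOT_PARSE_PREDICTION = -2
SUCCESS = 1
MATCHING_FAIL = 0
acc_score = 10.0

outcomes = {
    SUCCESS: "symbolic match via math_verify: ans_acc += 1, reward += acc_score",
    MATCHING_FAIL: "parsed but unequal (or empty/unparseable completion): no credit",
    CANNOT_PARSE_GT_ANSWER: "string fallback; full acc_score if normalized strings match",
    CANNOT_PARSE_PREDICTION: "string fallback; acc_score / 2 if normalized strings match",
}
for code, meaning in outcomes.items():
    print(f"{code:>2}: {meaning}")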