mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-06-22 21:49:08 +00:00
rewrite reward fn
This commit is contained in:
parent a6085ff676
commit d06042b434
@@ -127,7 +127,7 @@ class BaseConsumer:
                 eval_statistics = {
                     k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                 }
-                eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                 if dist.get_rank() == 0:
                     if hasattr(self, "wandb_run"):
                         self.wandb_run.log(eval_statistics, step=eval_global_step)
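The only functional change in this hunk is the "eval/" prefix on the aggregated metric keys, which makes experiment trackers such as wandb group evaluation metrics in their own panel section instead of mixing them with training metrics. A minimal sketch of the effect, with a hypothetical metric name and (sum, count) values:

import torch

# Aggregated (sum, count) pairs per metric, as produced by the reduction above.
eval_statistics = {"accuracy": (torch.tensor(3.0), torch.tensor(4.0))}

# Old key "accuracy" becomes "eval/accuracy"; wandb renders prefixed keys
# under a separate "eval" section in the dashboard.
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'eval/accuracy': 0.75}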
@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # plain text answer cannot be parsed, but is correct
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid latex math representation, but the answer is correct, receive half of the score
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
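For reference, a hedged usage sketch of the two new helpers, assuming the patched reward module above is importable; the answer strings and acc_score below are illustrative, not taken from the training data:

# verify_math_representation returns one of the status constants defined above.
status = verify_math_representation("\\boxed{\\frac{1}{2}}", "1/2")
# SUCCESS (1) if math_verify parses both sides and confirms equivalence.

# verify_model_answer folds that status into the running reward/accuracy:
reward, ans_acc = verify_model_answer(
    decoded_final_answer="\\boxed{\\frac{1}{2}}",
    gt_answer="1/2",
    ans_acc=0,
    acc_score=10.0,
    reward=0.0,
)
# On SUCCESS this yields reward == 10.0 and ans_acc == 1. If either side is
# unparseable but the strings match after stripping spaces, braces, and commas,
# the answer still counts: an unparseable ground truth earns the full acc_score,
# an unparseable prediction earns acc_score / 2.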
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode:
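Both math_reward_fn and boxed_math_reward_fn now share a single accuracy path: the inline verify(parse(...), parse(...)) check is replaced by a call to verify_model_answer, which adds the plain-text fallback for answers math_verify cannot parse. A condensed, illustrative sketch of the resulting control flow (the function name and variable setup are hypothetical; only the two delegation lines mirror the diff):

def _score_answer(decoded_final_answer, gt_answer, format_valid, final_answer, acc_score, length_reward):
    # Hypothetical condensation of the shared tail of both reward functions.
    reward, ans_acc = 0.0, 0
    if format_valid and final_answer is not None:
        # Delegates equivalence checking plus the half-credit fallback.
        reward, ans_acc = verify_model_answer(
            decoded_final_answer, gt_answer, ans_acc, acc_score, reward
        )
    reward = reward + length_reward
    return reward, ans_acc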