mirror of https://github.com/hpcaitech/ColossalAI.git
rewrite reward fn

commit d06042b434 (parent a6085ff676)
@@ -127,7 +127,7 @@ class BaseConsumer:
                         eval_statistics = {
                             k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
                         }
-                eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+                eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
                 if dist.get_rank() == 0:
                     if hasattr(self, "wandb_run"):
                         self.wandb_run.log(eval_statistics, step=eval_global_step)
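The only functional change in this hunk is the "eval/" prefix on each aggregated metric key, so evaluation metrics land in their own chart group in the Weights & Biases UI instead of mixing with training metrics. A minimal sketch of the effect, with a made-up metric name and counts:

import torch

# Hypothetical aggregate: (summed correct count, number of samples) per metric.
eval_statistics = {"accuracy": (torch.tensor(42.0), torch.tensor(50.0))}

# Previously logged as "accuracy"; now logged as "eval/accuracy", which wandb
# renders under a separate "eval" section of the dashboard.
eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'eval/accuracy': 0.8399999...}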
@@ -1,8 +1,70 @@
 import torch
-from math_verify import parse, verify
+from latex2sympy2_extended import NormalizationConfig
+from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
+CANNOT_PARSE_GT_ANSWER = -1
+CANNOT_PARSE_PREDICTION = -2
+SUCCESS = 1
+MATCHING_FAIL = 0
+
+
+def verify_math_representation(completion, gt_answer):
+    """
+    Verify if the completion is a valid math representation of the gt_answer.
+    """
+    target = (
+        ExprExtractionConfig(),
+        LatexExtractionConfig(
+            normalization_config=NormalizationConfig(
+                nits=False,
+                malformed_operators=False,
+                basic_latex=True,
+                boxed="all",
+                units=True,
+            ),
+            boxed_match_priority=0,
+        ),
+    )
+    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
+        raise ValueError("gt_answer should be a string, please verify your training data.")
+    if not isinstance(completion, str) or len(completion) == 0:
+        return MATCHING_FAIL
+    try:
+        parsed_gt_answer = parse(gt_answer, extraction_config=target)
+        if len(parsed_gt_answer) == 0:
+            return CANNOT_PARSE_GT_ANSWER
+        parsed_completion = parse(completion, extraction_config=target)
+        if len(parsed_completion) == 0:
+            return CANNOT_PARSE_PREDICTION
+        if verify(parsed_gt_answer, parsed_completion):
+            return SUCCESS
+        else:
+            return MATCHING_FAIL
+    except Exception:
+        return MATCHING_FAIL
+
+
+def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
+    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
+    if math_verify_result == SUCCESS:
+        ans_acc += 1
+        reward += acc_score
+    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
+        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
+            ",", ""
+        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
+            ans_acc += 1
+            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
+                # plain text answer cannot be parsed, but is correct
+                reward += acc_score
+            else:
+                reward += (
+                    acc_score / 2
+                )  # not a valid latex math representation, but the answer is correct, receive half of the score
+    return reward, ans_acc
+
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
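verify_math_representation folds math_verify's parse/verify pair into four integer statuses, which verify_model_answer then converts into reward tiers. A rough sketch of the intended status codes with made-up answers; exactly which strings parse depends on the installed math-verify and latex2sympy2_extended versions, so the comments describe intent rather than guaranteed output:

verify_math_representation("The answer is \\boxed{\\frac{1}{2}}", "0.5")  # SUCCESS (1): 1/2 equals 0.5
verify_math_representation("The answer is \\boxed{3}", "0.5")             # MATCHING_FAIL (0): parses, wrong value
verify_math_representation("\\boxed{3}", "three apples")                  # CANNOT_PARSE_GT_ANSWER (-1)
verify_math_representation("no final answer given", "3")                  # CANNOT_PARSE_PREDICTION (-2)
verify_math_representation("", "3")                                       # MATCHING_FAIL (0): empty completion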
@@ -36,9 +98,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
 
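math_reward_fn previously awarded all-or-nothing credit through an inline parse/verify call; it now delegates to verify_model_answer, which grades in tiers. A sketch of the tiers, with a hypothetical acc_score of 10.0 and invented argument values:

acc_score = 10.0  # hypothetical value for illustration

# Tier 1: symbolic match -> full score and an accuracy count.
reward, ans_acc = verify_model_answer("\\boxed{2}", "2", 0, acc_score, 0.0)  # expect reward 10.0, ans_acc 1

# Tier 2: gt_answer unparseable but normalized strings match -> full score,
# since the parse failure is the training data's fault, not the model's.
# Tier 3: prediction unparseable but normalized strings match -> half score (5.0),
# penalizing invalid LaTeX while still crediting the correct answer.
# Anything else: reward and ans_acc are returned unchanged.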
@@ -88,9 +149,8 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
-        ans_acc += 1
-        reward += acc_score
+    if format_valid and final_answer is not None:
+        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
 
     reward = reward + length_reward
     if not eval_mode:
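Both reward functions therefore share the same fallback when parsing fails: a plain string comparison after stripping whitespace, braces, and commas. A sketch of that normalization pulled out into a standalone helper (_normalize is a hypothetical name; the diff inlines the chain):

def _normalize(answer: str) -> str:
    # Mirrors the inline replace chain in verify_model_answer.
    return answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")

# "{1,234}" and "1234" compare equal after normalization, so a correct plain-text
# answer still earns credit even when math_verify cannot parse one side.
print(_normalize(" {1,234} ") == _normalize("1234"))  # True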