@@ -127,7 +127,7 @@ class BaseConsumer:
             eval_statistics = {
                 k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics
             }
-            eval_statistics = {"eval/" + k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
+            eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
             if dist.get_rank() == 0:
                 if hasattr(self, "wandb_run"):
                     self.wandb_run.log(eval_statistics, step=eval_global_step)
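Context for the hunk above: each value in eval_statistics is a (sum, count) tensor pair (hence v[0] / v[1]), merged with each local eval result and reduced to a mean on rank 0 before logging. A runnable sketch of that pattern with illustrative values, reflecting the new line's dropped "eval/" key prefix:

import torch

# Illustrative (sum, count) pairs; in the real code these come from eval rollouts.
eval_statistics = {"ans_acc": torch.tensor([3.0, 8.0])}
local_eval_result = {"ans_acc": torch.tensor([5.0, 8.0])}

# Merge another worker's result by elementwise addition, as in the context lines.
eval_statistics = {k: eval_statistics[k] + local_eval_result[k] for k in eval_statistics}

# Reduce each (sum, count) pair to a scalar mean, matching the updated line.
eval_statistics = {k: (v[0] / v[1]).item() for k, v in eval_statistics.items()}
print(eval_statistics)  # {'ans_acc': 0.5}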
@@ -1,70 +1,8 @@
 import torch
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
+from math_verify import parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
-CANNOT_PARSE_GT_ANSWER = -1
-CANNOT_PARSE_PREDICTION = -2
-SUCCESS = 1
-MATCHING_FAIL = 0
-
-
-def verify_math_representation(completion, gt_answer):
-    """
-    Verify if the completion is a valid math representation of the gt_answer.
-    """
-    target = (
-        ExprExtractionConfig(),
-        LatexExtractionConfig(
-            normalization_config=NormalizationConfig(
-                nits=False,
-                malformed_operators=False,
-                basic_latex=True,
-                boxed="all",
-                units=True,
-            ),
-            boxed_match_priority=0,
-        ),
-    )
-    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
-        raise ValueError("gt_answer should be a string, please verify your training data.")
-    if not isinstance(completion, str) or len(completion) == 0:
-        return MATCHING_FAIL
-    try:
-        parsed_gt_answer = parse(gt_answer, extraction_config=target)
-        if len(parsed_gt_answer) == 0:
-            return CANNOT_PARSE_GT_ANSWER
-        parsed_completion = parse(completion, extraction_config=target)
-        if len(parsed_completion) == 0:
-            return CANNOT_PARSE_PREDICTION
-        if verify(parsed_gt_answer, parsed_completion):
-            return SUCCESS
-        else:
-            return MATCHING_FAIL
-    except Exception:
-        return MATCHING_FAIL
-
-
-def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
-    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
-    if math_verify_result == SUCCESS:
-        ans_acc += 1
-        reward += acc_score
-    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
-        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
-            ",", ""
-        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
-            ans_acc += 1
-            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
-                # plain text answer cannot be parsed, but is correct
-                reward += acc_score
-            else:
-                reward += (
-                    acc_score / 2
-                )  # not a valid latex math representation, but the answer is correct, receive half of the score
-    return reward, ans_acc
-
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
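The deleted helpers mapped parse/verify outcomes onto four statuses and gave half credit for plain-text matches that could not be parsed as LaTeX; after this change, correctness rests entirely on math_verify. A minimal sketch of the remaining check, with the defensive wrapper the old try/except provided (the helper name is illustrative, not from the repo):

from math_verify import parse, verify


def is_equivalent(gt_answer: str, final_answer: str) -> bool:
    """Return True iff math_verify judges the two answers mathematically equal."""
    try:
        gold = parse(gt_answer.strip())  # parse() returns a list of candidates
        pred = parse(final_answer.strip())  # empty list -> nothing extractable
        return bool(gold) and bool(pred) and verify(gold, pred)
    except Exception:
        # Mirror the deleted MATCHING_FAIL behavior: never crash the reward pass.
        return False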
@@ -98,8 +36,9 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
+        ans_acc += 1
+        reward += acc_score
 
     reward = reward + length_reward
 
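With verify_model_answer removed, the scoring in math_reward_fn is a flat sum of format, accuracy, and length terms. A hypothetical trace with illustrative values (the real acc_score, format_score, and length_reward come from the trainer's reward configuration):

# Hypothetical values; in the real function answers_match comes from
# verify(parse(gt_answer.strip()), parse(final_answer.strip())).
acc_score, format_score, length_reward = 10.0, 1.0, 0.0
reward, format_acc, ans_acc = 0.0, 0, 0
format_valid, final_answer, answers_match = True, "0.5", True

if format_valid:
    format_acc += 1
    reward += format_score
if format_valid and final_answer is not None and answers_match:
    ans_acc += 1
    reward += acc_score
reward = reward + length_reward
print(reward, format_acc, ans_acc)  # 11.0 1 1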
@@ -149,8 +88,9 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and verify(parse(gt_answer.strip()), parse(final_answer.strip())):
+        ans_acc += 1
+        reward += acc_score
 
     reward = reward + length_reward
     if not eval_mode:
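A practical effect of delegating comparison to math_verify in both reward functions: equivalent answers in different notations verify as equal, with no half-credit fallback. A quick check (outputs noted as expectations, not verified here; assumes the math-verify package is installed):

from math_verify import parse, verify

# Gold answer first, prediction second, matching the argument order in the diff.
print(verify(parse("1/2"), parse("0.5")))           # expected: True
print(verify(parse(r"\frac{1}{2}"), parse("0.5")))  # expected: True
print(verify(parse("1/2"), parse("0.3")))           # expected: False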