import torch
from latex2sympy2_extended import NormalizationConfig
from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

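# Status codes returned by verify_math_representation: the negative values flag a
# parsing failure (of the ground truth or of the prediction, respectively), while
# SUCCESS / MATCHING_FAIL report the outcome of the equivalence check.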
CANNOT_PARSE_GT_ANSWER = -1
CANNOT_PARSE_PREDICTION = -2
SUCCESS = 1
MATCHING_FAIL = 0


def verify_math_representation(completion, gt_answer):
    """
    Check whether `completion` is a mathematically equivalent representation of `gt_answer`.

    Returns SUCCESS, MATCHING_FAIL, CANNOT_PARSE_GT_ANSWER, or CANNOT_PARSE_PREDICTION.
    """
    # Validate inputs before calling str methods on them, so malformed data fails
    # with the intended ValueError / MATCHING_FAIL instead of an AttributeError.
    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
        raise ValueError("gt_answer should be a non-empty string, please verify your training data.")
    if not isinstance(completion, str) or len(completion) == 0:
        return MATCHING_FAIL
    # Wrap bare answers in \boxed{} so the LaTeX extractor can pick them up.
    if not completion.startswith("\\boxed{"):
        completion = "\\boxed{" + completion + "}"
    if not gt_answer.startswith("\\boxed{"):
        gt_answer = "\\boxed{" + gt_answer + "}"
    target = (
        ExprExtractionConfig(),
        LatexExtractionConfig(
            normalization_config=NormalizationConfig(
                nits=False,
                malformed_operators=False,
                basic_latex=True,
                boxed="all",
                units=True,
            ),
            boxed_match_priority=0,
        ),
    )
    try:
        parsed_gt_answer = parse(gt_answer, extraction_config=target)
        if len(parsed_gt_answer) == 0:
            return CANNOT_PARSE_GT_ANSWER
        parsed_completion = parse(completion, extraction_config=target)
        if len(parsed_completion) == 0:
            return CANNOT_PARSE_PREDICTION
        if verify(parsed_gt_answer, parsed_completion):
            return SUCCESS
        return MATCHING_FAIL
    except Exception:
        return MATCHING_FAIL
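

# Usage sketch (illustrative; the exact equivalences depend on math_verify's parser):
#   verify_math_representation("\\frac{1}{2}", "0.5")  # expected: SUCCESS (equal values)
#   verify_math_representation("0.25", "0.5")          # expected: MATCHING_FAIL
#   verify_math_representation("", "0.5")              # MATCHING_FAIL (empty prediction)

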
def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
    """
    Score a decoded answer against the ground truth and update the accumulated
    reward and answer-accuracy counters.
    """
    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
    exact_match_result = (
        SUCCESS
        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
        == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", "")
        else MATCHING_FAIL
    )
    if math_verify_result == SUCCESS:
        ans_acc += 1
        reward += acc_score
    elif exact_match_result == SUCCESS:
        # math_verify sometimes fails on answers that are not (valid) math
        # expressions, so fall back to the exact string match.
        ans_acc += 1
        if math_verify_result == CANNOT_PARSE_PREDICTION:
            # Not a valid LaTeX math representation, but the answer is correct:
            # award half of the accuracy score.
            reward += acc_score / 2
        else:
            reward += acc_score
    return reward, ans_acc
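

# Accumulation sketch (hypothetical values): starting from reward=0.0 and
# ans_acc=0.0 with acc_score=10.0, a fully verified answer returns (10.0, 1.0);
# an exact string match whose prediction math_verify cannot parse returns
# (5.0, 1.0); a mismatch leaves both values unchanged.

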
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
    """
    Reward function that validates the response structure against the configured
    tags and verifies the extracted answer against the ground truth.
    """
    tokenizer = kwargs["tokenizer"]
    eval_mode = kwargs.get("eval_mode", False)
    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
    acc_score = 10.0
    reward = torch.tensor(0.0)
    format_acc = torch.tensor(0.0)
    ans_acc = torch.tensor(0.0)
    s, e = response_idx[0], response_idx[1]

    length_reward = 0.0
    res_length = e.item() - s.item() + 1
    if not eval_mode:
        max_new_tokens = kwargs["max_new_tokens"]
    else:
        max_new_tokens = -1  # for eval mode, we don't need to check the length
    if not eval_mode and soft_over_length_punishment:
        cache_length = kwargs["cache_length"]
        if max_new_tokens - cache_length < res_length < max_new_tokens:
            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

    if gt_answer is None:
        raise ValueError("no gt_answer is provided, please check your training dataset.")

    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
    final_answer, processed_str = extract_solution(decoded_final_answer)

    format_valid = validate_response_structure(processed_str, kwargs["tags"])

    # Check format accuracy
    if format_valid:
        format_acc += 1

    # Check answer accuracy; during training the answer only counts when the
    # format is also valid.
    if final_answer is not None:
        if eval_mode or format_valid:
            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
        if not eval_mode:
            reward = reward + length_reward

    # Zero out the reward if the sequence is over length
    if not eval_mode and res_length >= max_new_tokens:
        reward *= 0.0

    if not eval_mode:
        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
    else:
        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
        return {
            "prompt": prompt,
            "prediction": decoded_final_answer,
            "gold": gt_answer,
            "parsed": final_answer,
            "format_valid": format_acc.item(),
            "ans_valid": ans_acc.item(),
            "response_length": res_length,
            "reward": reward.item(),
        }
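

# Invocation sketch for math_reward_fn (hypothetical argument values; in practice
# the training pipeline supplies these via **kwargs):
#   math_reward_fn(
#       input_ids,                       # 1-D tensor of prompt + response token ids
#       gt_answer,                       # tensor of ground-truth answer token ids
#       response_idx,                    # (start, end) indices of the response span
#       tokenizer=tokenizer,
#       tags=tags,                       # tag config consumed by validate_response_structure
#       max_new_tokens=4096,             # hypothetical generation budget
#       cache_length=512,                # hypothetical soft-punishment window
#       soft_over_length_punishment=True,
#   )

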
def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
    """
    Reward function that extracts the final answer from a \\boxed{} expression
    and verifies it against the ground truth.
    """
    tokenizer = kwargs["tokenizer"]
    eval_mode = kwargs.get("eval_mode", False)
    soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
    format_score = 0.0
    acc_score = 10.0
    reward = torch.tensor(0.0)
    format_acc = torch.tensor(0.0)
    ans_acc = torch.tensor(0.0)
    s, e = response_idx[0], response_idx[1]

    length_reward = 0.0
    res_length = e.item() - s.item() + 1
    if not eval_mode:
        max_new_tokens = kwargs["max_new_tokens"]
    else:
        max_new_tokens = -1  # for eval mode, we don't need to check the length
    if not eval_mode and soft_over_length_punishment:
        cache_length = kwargs["cache_length"]
        if max_new_tokens - cache_length < res_length < max_new_tokens:
            length_reward = ((max_new_tokens - cache_length) - res_length) / cache_length * acc_score

    if gt_answer is None:
        raise ValueError("no gt_answer is provided, please check your training dataset.")

    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)

    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
    final_answer = extract_boxed_solution(decoded_final_answer)
    format_valid = final_answer is not None
    if "tags" in kwargs and kwargs["tags"]:
        # Additionally require each configured tag to appear the expected number of times.
        tags = kwargs["tags"]
        format_valid = format_valid and all(
            decoded_final_answer.count(tags[tag]["text"]) == tags[tag]["num_occur"] for tag in tags
        )
    # Check format accuracy
    if format_valid:
        format_acc += 1
        reward += format_score

    # Check answer accuracy; during training the answer only counts when the
    # format is also valid.
    if final_answer is not None:
        if eval_mode or format_valid:
            reward, ans_acc = verify_model_answer(final_answer, gt_answer, ans_acc, acc_score, reward)
        if not eval_mode:
            reward = reward + length_reward

    # Zero out the reward if the sequence is over length
    if not eval_mode and res_length >= max_new_tokens:
        reward *= 0.0

    if not eval_mode:
        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
    else:
        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
        return {
            "prompt": prompt,
            "prediction": decoded_final_answer,
            "gold": gt_answer,
            "parsed": final_answer,
            "format_valid": format_acc.item(),
            "ans_valid": ans_acc.item(),
            "response_length": res_length,
            "reward": reward.item(),
        }
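

# Return-value sketch: in training mode both reward functions return a tensor
# [reward, format_acc, ans_acc] on input_ids' device; in eval mode they return
# a dict with the decoded prompt/prediction and per-sample metrics.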