Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-09-14 21:51:57 +00:00)
[fix] revert reward update and evaluation (#6295)
* Revert "rewrite reward fn" This reverts commitd06042b434
. * Revert "upgrade reward math verification" This reverts commita6085ff676
. * Revert "fix bug" This reverts commit01640ebd65
. * Revert "reuse comm-group" This reverts commitbd61918dcf
. * Revert "Support evaluation during training" This reverts commit57a88395fe
.
@@ -1,74 +1,10 @@
 import torch
-from latex2sympy2_extended import NormalizationConfig
-from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify
 
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 
-CANNOT_PARSE_GT_ANSWER = -1
-CANNOT_PARSE_PREDICTION = -2
-SUCCESS = 1
-MATCHING_FAIL = 0
-
-
-def verify_math_representation(completion, gt_answer):
-    """
-    Verify if the completion is a valid math representation of the gt_answer.
-    """
-    target = (
-        ExprExtractionConfig(),
-        LatexExtractionConfig(
-            normalization_config=NormalizationConfig(
-                nits=False,
-                malformed_operators=False,
-                basic_latex=True,
-                boxed="all",
-                units=True,
-            ),
-            boxed_match_priority=0,
-        ),
-    )
-    if not isinstance(gt_answer, str) or len(gt_answer) == 0:
-        raise ValueError("gt_answer should be a string, please verify your training data.")
-    if not isinstance(completion, str) or len(completion) == 0:
-        return MATCHING_FAIL
-    try:
-        parsed_gt_answer = parse(gt_answer, extraction_config=target)
-        if len(parsed_gt_answer) == 0:
-            return CANNOT_PARSE_GT_ANSWER
-        parsed_completion = parse(completion, extraction_config=target)
-        if len(parsed_completion) == 0:
-            return CANNOT_PARSE_PREDICTION
-        if verify(parsed_gt_answer, parsed_completion):
-            return SUCCESS
-        else:
-            return MATCHING_FAIL
-    except Exception:
-        return MATCHING_FAIL
-
-
-def verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward):
-    math_verify_result = verify_math_representation(decoded_final_answer, gt_answer)
-    if math_verify_result == SUCCESS:
-        ans_acc += 1
-        reward += acc_score
-    elif math_verify_result == CANNOT_PARSE_GT_ANSWER or math_verify_result == CANNOT_PARSE_PREDICTION:
-        if decoded_final_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(
-            ",", ""
-        ) == gt_answer.strip().replace(" ", "").replace("{", "").replace("}", "").replace(",", ""):
-            ans_acc += 1
-            if math_verify_result == CANNOT_PARSE_GT_ANSWER:
-                # plain text answer cannot be parsed, but is correct
-                reward += acc_score
-            else:
-                reward += (
-                    acc_score / 2
-                )  # not a valid latex math representation, but the answer is correct, receive half of the score
-    return reward, ans_acc
-
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
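
Note: the helper removed above delegated equivalence checking to the math_verify library's parse/verify pair, imported at the top of the deleted block. A minimal standalone sketch of that check; the example strings are illustrative, not from the training data:

from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

config = (ExprExtractionConfig(), LatexExtractionConfig())
gold = parse("$\\frac{1}{2}$", extraction_config=config)
pred = parse("The answer is \\boxed{0.5}", extraction_config=config)

# verify() accepts mathematically equivalent forms that the plain string
# comparison this commit restores would reject.
print(verify(gold, pred))  # expected: True
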
@@ -98,28 +34,46 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         format_acc += 1
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if (
+        format_valid
+        and final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        ans_acc += 1
+        reward += acc_score
 
     reward = reward + length_reward
 
-    if not eval_mode:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-    else:
-        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
-        return {
-            "prompt": prompt,
-            "prediction": decoded_final_answer,
-            "gold": gt_answer,
-            "parsed": final_answer,
-            "format_valid": format_acc.item(),
-            "ans_valid": ans_acc.item(),
-        }
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+
+
+def gsm8k_reward_fn(input_ids, **kwargs):
+    gt_answer = kwargs["gt_answer"]
+    tokenizer = kwargs["tokenizer"]
+    s, e = kwargs["response_start"], kwargs["response_end"]
+    reward = torch.tensor(0.0).to(input_ids.device)
+    if gt_answer is None:
+        return reward
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    final_answer, processed_str = extract_solution(decoded_final_answer)
+    is_valid = True
+    try:
+        int(final_answer.strip())
+    except Exception:
+        is_valid = False
+
+    format_valid = validate_response_structure(processed_str, kwargs["tags"])
+    if not is_valid or not format_valid:
+        return reward
+
+    reward += 1.0
+    if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
+        reward = reward + 9.0
+    return reward
 
 
 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
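
Note: after the revert, answer checking in math_reward_fn reduces to the normalized string comparison shown in the hunk above. A standalone sketch of that predicate; the helper name and sample values are illustrative, not part of the commit:

def answers_match(gt_answer: str, final_answer: str) -> bool:
    # Mirrors the restored check: strip, drop spaces, lowercase, exact match.
    def normalize(s: str) -> str:
        return s.strip().replace(" ", "").lower()
    return normalize(gt_answer) == normalize(final_answer)

assert answers_match("1/2", " 1 / 2 ")   # whitespace differences are ignored
assert not answers_match("1/2", "0.5")   # mathematically equal, but no longer a match
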
@@ -137,7 +91,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+        return reward
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
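
Note: the soft over-length punishment itself is untouched by the revert; the hunks only change lines around it. A worked example of the formula under assumed values (none of these numbers are fixed by the commit):

max_length = 4096    # total generation budget, illustrative
cache_length = 512   # soft margin before the hard cap, illustrative
acc_score = 10.0
res_length = 3840    # tokens actually generated, illustrative

# Past (max_length - cache_length) the penalty ramps down linearly,
# reaching -acc_score when res_length hits max_length.
length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
print(length_reward)  # (3584 - 3840) / 512 * 10.0 = -5.0
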
@@ -149,19 +103,10 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward += format_score
 
     # Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
-    if format_valid and final_answer is not None:
-        reward, ans_acc = verify_model_answer(decoded_final_answer, gt_answer, ans_acc, acc_score, reward)
+    if format_valid and final_answer is not None and gt_answer.strip().lower() == final_answer.strip().lower():
+        ans_acc += 1
+        reward += acc_score
 
     reward = reward + length_reward
-    if not eval_mode:
-        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-    else:
-        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
-        return {
-            "prompt": prompt,
-            "prediction": decoded_final_answer,
-            "gold": gt_answer,
-            "parsed": final_answer,
-            "format_valid": format_acc.item(),
-            "ans_valid": ans_acc.item(),
-        }
+
+    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
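
Note: with eval_mode gone, every path in the restored reward functions returns a tensor again: a scalar for gsm8k_reward_fn and a [reward, format_acc, ans_acc] triple for the others, never a dict. A hypothetical consumer of that triple; the names and values here are illustrative only:

import torch

# One [reward, format_acc, ans_acc] row per sampled response.
batch = torch.stack(
    [
        torch.tensor([10.0, 1.0, 1.0]),  # well-formatted and correct
        torch.tensor([0.0, 1.0, 0.0]),   # well-formatted but wrong
    ]
)
rewards, format_acc, ans_acc = batch.unbind(dim=-1)
print(rewards.mean().item(), format_acc.mean().item(), ans_acc.mean().item())
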