Mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-08 12:30:42 +00:00
Support evaluation during training
@@ -5,6 +5,7 @@ from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     acc_score = 10.0
     reward = torch.tensor(0.0)
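The only change in this first hunk is the new eval_mode flag, read from **kwargs with a False default, so existing training call sites keep working unchanged. A minimal sketch of how an evaluation loop might opt in (the call site below is illustrative, not part of this commit; input_ids, gt_answer, response_idx, and tokenizer are assumed to come from the rollout pipeline):

    # hypothetical eval-time call; eval_mode=True flips the return type
    record = math_reward_fn(
        input_ids,
        gt_answer,
        response_idx,
        tokenizer=tokenizer,
        eval_mode=True,  # assumption: omitted (False) during training
    )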
@@ -44,36 +45,23 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward = reward + length_reward
 
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
-
-
-def gsm8k_reward_fn(input_ids, **kwargs):
-    gt_answer = kwargs["gt_answer"]
-    tokenizer = kwargs["tokenizer"]
-    s, e = kwargs["response_start"], kwargs["response_end"]
-    reward = torch.tensor(0.0).to(input_ids.device)
-    if gt_answer is None:
-        return reward
-    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
-    final_answer, processed_str = extract_solution(decoded_final_answer)
-    is_valid = True
-    try:
-        int(final_answer.strip())
-    except Exception:
-        is_valid = False
-
-    format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not is_valid or not format_valid:
-        return reward
-
-    reward += 1.0
-    if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-        reward = reward + 9.0
-    return reward
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
 
 
 def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
+    eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
     format_score = 0.0
     acc_score = 10.0
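This hunk replaces the unconditional tensor return with a dual contract: during training the function still returns the packed tensor [reward, format_acc, ans_acc], while in eval mode it returns a per-sample dict of decoded strings and scalar accuracy flags; the old gsm8k_reward_fn is dropped along the way. The dict form makes offline aggregation straightforward. A small sketch of such aggregation, assuming records is a list of dicts collected from eval_mode=True calls (the helper below is hypothetical, not part of this commit):

    def summarize(records):
        # records: dicts as returned by the reward fns with eval_mode=True
        n = max(len(records), 1)
        return {
            "format_acc": sum(r["format_valid"] for r in records) / n,
            "ans_acc": sum(r["ans_valid"] for r in records) / n,
        }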
@@ -91,7 +79,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         length_reward = ((max_length - cache_length) - res_length) / cache_length * acc_score
 
     if gt_answer is None:
-        return reward
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
 
     decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
     gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
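This one-line change makes the missing-label early return shape-consistent with the normal path: every training-mode exit from boxed_math_reward_fn now yields a 3-element tensor, so downstream code can stack per-sample outputs without special-casing. An illustrative check (not from this commit):

    import torch
    # a zero-reward early exit and a full-score exit now stack cleanly
    outs = [torch.tensor([0.0, 0.0, 0.0]), torch.tensor([10.0, 1.0, 1.0])]
    rewards, format_acc, ans_acc = torch.stack(outs).unbind(dim=1)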
@@ -108,5 +96,15 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
             reward += acc_score
 
     reward = reward + length_reward
 
-    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    if not eval_mode:
+        return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+    else:
+        prompt = tokenizer.decode(input_ids[:s], skip_special_tokens=True)
+        return {
+            "prompt": prompt,
+            "prediction": decoded_final_answer,
+            "gold": gt_answer,
+            "parsed": final_answer,
+            "format_valid": format_acc.item(),
+            "ans_valid": ans_acc.item(),
+        }
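Taken together, both surviving reward functions now share the same train/eval contract. A minimal end-to-end sketch, assuming the module is importable as shown (the import path and the surrounding variables are assumptions drawn from the diff context, not confirmed by this commit):

    # hypothetical smoke test of the dual return contract
    from coati.distributed.reward.reward_fn import boxed_math_reward_fn  # path assumed

    out = boxed_math_reward_fn(input_ids, gt_answer, response_idx, tokenizer=tokenizer)
    rec = boxed_math_reward_fn(input_ids, gt_answer, response_idx, tokenizer=tokenizer, eval_mode=True)
    assert out.shape == (3,)  # training mode: tensor([reward, format_acc, ans_acc])
    assert isinstance(rec, dict) and "prediction" in rec  # eval mode: loggable record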