update reward fn

This commit is contained in:
Tong Li 2025-03-10 14:18:22 +08:00
parent 9d9d51614e
commit 754b16dfbf

View File

@ -5,7 +5,9 @@ from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs): def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"] tokenizer = kwargs["tokenizer"]
reward = torch.tensor(0.0).to(input_ids.device) reward = torch.tensor(0.0)
format_reward = torch.tensor(0.0)
acc_reward = torch.tensor(0.0)
s, e = response_idx[0], response_idx[1] s, e = response_idx[0], response_idx[1]
if gt_answer is None: if gt_answer is None:
return reward return reward
@ -15,13 +17,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
final_answer, processed_str = extract_solution(decoded_final_answer) final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"]) format_valid = validate_response_structure(processed_str, kwargs["tags"])
if not format_valid:
return reward # Check format accuracy
else: if format_valid:
format_reward += 1.0
reward += 1.0 reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 2.0 # Check answer accuracy
return reward if (
final_answer is not None
and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
):
acc_reward += 5.0
reward += 5.0
return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
def gsm8k_reward_fn(input_ids, **kwargs): def gsm8k_reward_fn(input_ids, **kwargs):