This commit is contained in:
Tong Li
2025-02-28 10:16:42 +08:00
parent f736d747e3
commit 070907dd7f
6 changed files with 74 additions and 26 deletions

View File

@@ -3,12 +3,14 @@ import torch
from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, gt_answer, **kwargs):
def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
tokenizer = kwargs["tokenizer"]
reward = torch.tensor(0.0).to(input_ids.device)
s, e = response_idx[0], response_idx[1]
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
final_answer, processed_str = extract_solution(decoded_final_answer)
@@ -29,7 +31,7 @@ def gsm8k_reward_fn(input_ids, **kwargs):
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s:e], skip_special_tokens=True)
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
is_valid = True
try: