update reward fn

This commit is contained in:
Tong Li 2025-03-06 10:53:48 +08:00
parent 678f5a9eca
commit d03cdea949

View File

@ -11,7 +11,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])
@ -20,7 +20,7 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
else:
reward += 1.0
if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
reward = reward + 9.0
reward = reward + 2.0
return reward