add simple grpo

This commit is contained in:
Tong Li
2025-02-23 22:54:26 +08:00
parent 8e6c9a4ab3
commit ffd3878a1e
8 changed files with 253 additions and 21 deletions

View File

@@ -3,17 +3,13 @@ import torch
from .reward_utils import extract_solution, validate_response_structure
def math_reward_fn(input_ids, **kwargs):
# apply varifiable reward
# reward 10 points if the final answer is correct, reward 1 point if format is correct
gt_answer = kwargs["gt_answer"]
def math_reward_fn(input_ids, gt_answer, **kwargs):
tokenizer = kwargs["tokenizer"]
s, e = kwargs["response_start"], kwargs["response_end"]
reward = torch.tensor(0.0).to(input_ids.device)
if gt_answer is None:
return reward
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
decoded_final_answer = tokenizer.decode(input_ids, skip_special_tokens=True)
gt_answer = tokenizer.decode(gt_answer.squeeze(0))
final_answer, processed_str = extract_solution(decoded_final_answer)
format_valid = validate_response_structure(processed_str, kwargs["tags"])