add simple grpo

This commit is contained in:
Tong Li
2025-02-23 22:54:26 +08:00
parent 8e6c9a4ab3
commit ffd3878a1e
8 changed files with 253 additions and 21 deletions

View File

@@ -8,33 +8,27 @@ import torch
class VerifiableReward:
def __init__(self, reward_fn: List[callable], reward_args: List[Dict[str, Any]]):
self.reward_fn = reward_fn
self.reward_args = reward_args
def __init__(self, reward_fns: List[callable], **kwargs: List[Dict[str, Any]]):
self.reward_fns = reward_fns
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
response_start: List[int] = None,
response_end: List[int] = None,
gt_answer: List[str] = None,
gt_answer: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
# Initialize reward
reward = torch.zeros(bs, device=input_ids.device)
rewards = torch.zeros(bs, device=input_ids.device)
# Loop through reward functions
for reward_fn in self.reward_fn_list:
for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
attention_mask[i],
response_start=response_start[i],
response_end=response_end[i],
gt_answer=gt_answer[i],
**self.kwargs,
)