support code generation tasks

This commit is contained in:
YeAnbang
2025-06-05 17:56:42 +08:00
parent ceb7065d6d
commit dc3033e68a
12 changed files with 1027 additions and 127 deletions

View File

@@ -2,6 +2,7 @@
Function-based reward verification module.
"""
import inspect
from typing import Any, Dict, List
import torch
@@ -15,7 +16,8 @@ class VerifiableReward:
def __call__(
self,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
gt_answer: List[str] = None,
test_cases: List[str] = None,
response_idx: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
@@ -26,18 +28,44 @@ class VerifiableReward:
# Loop through reward functions
for reward_fn in self.reward_fns:
# Apply the reward function to the entire batch at once
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
if "gt_answer" in inspect.getfullargspec(reward_fn).args:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
elif "test_cases" in inspect.getfullargspec(reward_fn).args:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
test_cases=test_cases[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
else:
reward_batch = torch.stack(
[
reward_fn(
input_ids[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)
],
dim=0,
)
rewards += reward_batch
return rewards