This commit is contained in:
Tong Li
2025-02-28 10:16:42 +08:00
parent f736d747e3
commit 070907dd7f
6 changed files with 74 additions and 26 deletions

View File

@@ -16,6 +16,7 @@ class VerifiableReward:
self,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
response_idx: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
@@ -30,6 +31,7 @@ class VerifiableReward:
reward_fn(
input_ids[i],
gt_answer=gt_answer[i],
response_idx=response_idx[i],
**self.kwargs,
)
for i in range(bs)