mirror of https://github.com/hpcaitech/ColossalAI.git, synced 2025-09-17 07:00:37 +00:00
support code generation tasks
@@ -2,6 +2,7 @@
 Function-based reward verification module.
 """
 
+import inspect
 from typing import Any, Dict, List
 
 import torch
@@ -15,7 +16,8 @@ class VerifiableReward:
     def __call__(
         self,
         input_ids: torch.LongTensor,
-        gt_answer: List[torch.Tensor] = None,
+        gt_answer: List[str] = None,
+        test_cases: List[str] = None,
         response_idx: List[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Get batch size
@@ -26,18 +28,44 @@ class VerifiableReward:
         # Loop through reward functions
         for reward_fn in self.reward_fns:
             # Apply the reward function to the entire batch at once
-            reward_batch = torch.stack(
-                [
-                    reward_fn(
-                        input_ids[i],
-                        gt_answer=gt_answer[i],
-                        response_idx=response_idx[i],
-                        **self.kwargs,
-                    )
-                    for i in range(bs)
-                ],
-                dim=0,
-            )
+            if "gt_answer" in inspect.getfullargspec(reward_fn).args:
+                reward_batch = torch.stack(
+                    [
+                        reward_fn(
+                            input_ids[i],
+                            gt_answer=gt_answer[i],
+                            response_idx=response_idx[i],
+                            **self.kwargs,
+                        )
+                        for i in range(bs)
+                    ],
+                    dim=0,
+                )
+            elif "test_cases" in inspect.getfullargspec(reward_fn).args:
+                reward_batch = torch.stack(
+                    [
+                        reward_fn(
+                            input_ids[i],
+                            test_cases=test_cases[i],
+                            response_idx=response_idx[i],
+                            **self.kwargs,
+                        )
+                        for i in range(bs)
+                    ],
+                    dim=0,
+                )
+            else:
+                reward_batch = torch.stack(
+                    [
+                        reward_fn(
+                            input_ids[i],
+                            response_idx=response_idx[i],
+                            **self.kwargs,
+                        )
+                        for i in range(bs)
+                    ],
+                    dim=0,
+                )
 
             rewards += reward_batch
         return rewards
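The new branches dispatch on each reward function's signature: inspect.getfullargspec(reward_fn).args decides whether the function receives gt_answer (answer-verification rewards), test_cases (code-generation rewards), or neither. Below is a minimal sketch of that routing; acc_reward_fn and code_reward_fn are hypothetical illustrations, not functions shipped in this commit.

import inspect

import torch


def acc_reward_fn(input_ids, gt_answer=None, response_idx=None, **kwargs):
    # Hypothetical answer-verification reward keyed on a ground-truth string.
    return torch.tensor(1.0 if gt_answer is not None else 0.0)


def code_reward_fn(input_ids, test_cases=None, response_idx=None, **kwargs):
    # Hypothetical code-generation reward: fraction of test cases that pass.
    passed = 0  # placeholder; a real implementation would execute the generated code
    total = max(len(test_cases or []), 1)
    return torch.tensor(passed / total)


for fn in (acc_reward_fn, code_reward_fn):
    args = inspect.getfullargspec(fn).args
    if "gt_answer" in args:
        branch = "gt_answer"
    elif "test_cases" in args:
        branch = "test_cases"
    else:
        branch = "default"
    print(f"{fn.__name__} is routed through the {branch} branch")

Keying the dispatch on the function signature lets answer-based and test-case-based rewards coexist in the same reward_fns list without extra per-task configuration.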