agpo reward

Chen Li 2025-05-13 18:36:13 +08:00
parent e08626d740
commit a5380d7073
2 changed files with 82 additions and 0 deletions

@@ -0,0 +1,57 @@
"""
Function-based reward verification module.
"""
from typing import Any, Dict, List
import torch
class AGPOReward:
def __init__(self, reward_fn: callable, **kwargs: List[Dict[str, Any]]):
self.reward_fn = reward_fn
self.kwargs = kwargs
def __call__(
self,
input_ids: torch.LongTensor,
gt_answer: List[torch.Tensor] = None,
response_idx: List[torch.Tensor] = None,
) -> torch.Tensor:
# Get batch size
bs = input_ids.size(0)
# Initialize reward
rewards = torch.zeros((bs, 3), device=input_ids.device)
len_rewards = torch.zeros((bs), device=input_ids.device)
num_generations = self.kwargs.get("num_generations")
# Apply the reward function to the entire batch at once
reward_infos = [self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs) for i in range(bs)]
reward_batch = torch.stack([info[0] for info in reward_infos])
format_acc_batch = torch.stack([info[1] for info in reward_infos])
ans_acc_batch = torch.stack([info[2] for info in reward_infos])
seq_len_batch = torch.stack([info[3] for info in reward_infos])
group_reward = reward_batch.view(-1, num_generations)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
mask_zero_std = reward_std == 0
group_seq_len = seq_len_batch.view(-1, num_generations)
group_ans_acc = ans_acc_batch.view(-1, num_generations)
mask_incorrect = group_ans_acc == 0
masked_seq_len_for_max = group_seq_len.masked_fill(mask_incorrect, float('-inf'))
max_group_seq_len = masked_seq_len_for_max.max(dim=1).repeat_interleave(num_generations, dim=0)
masked_seq_len_for_min = group_seq_len.masked_fill(mask_incorrect, float('inf'))
min_group_seq_len = masked_seq_len_for_min.min(dim=1).repeat_interleave(num_generations, dim=0)
len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
len_rewards = 0.1 * (1 - len_ratio)
len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)
len_rewards = len_rewards.masked_fill(mask_incorrect, 0.0)
reward_batch += len_rewards
rewards += torch.stack(
[torch.tensor([r, f, a]).to(input_ids.device) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)],
dim=0
)
return rewards
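
For context, a minimal usage sketch of AGPOReward. The stub reward function, tensor shapes, and sample data below are illustrative assumptions, not part of this commit; the stub only follows the (reward, format_acc, ans_acc, seq_len) return contract that __call__ expects. Note that gt_answer and response_idx are per-sample lists, while input_ids is a single batched tensor.

import torch

# Hypothetical stub with the same return contract as agpo_boxed_math_reward_fn:
# (reward, format_acc, ans_acc, seq_len), all 0-dim tensors
def stub_reward_fn(input_ids, gt_answer=None, response_idx=None, **kwargs):
    s, e = response_idx[0], response_idx[1]
    ans_acc = gt_answer.float().squeeze()  # 1.0 = correct, 0.0 = wrong
    reward = torch.where(ans_acc > 0, torch.tensor(1.0), torch.tensor(-1.0))
    return reward, torch.tensor(1.0), ans_acc, (e - s + 1).float()

bs, num_generations, seq = 6, 3, 16  # 2 prompts x 3 generations each
input_ids = torch.randint(0, 100, (bs, seq))
gt_answers = [torch.tensor([i % 2]) for i in range(bs)]        # alternate wrong/correct
response_idxs = [torch.tensor([2, 2 + i]) for i in range(bs)]  # varying response lengths

reward_model = AGPOReward(stub_reward_fn, num_generations=num_generations)
rewards = reward_model(input_ids, gt_answer=gt_answers, response_idx=response_idxs)
print(rewards.shape)  # torch.Size([6, 3]): [shaped reward, format_acc, ans_acc] per sample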

@@ -110,3 +110,28 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
    reward = reward + length_reward
    return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)


def agpo_boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
    tokenizer = kwargs["tokenizer"]
    reward_correct = torch.tensor(1.0)
    reward_wrong = torch.tensor(-1.0)
    ans_acc = torch.tensor(0.0)
    format_acc = torch.tensor(0.0)

    # Response span and its length in tokens
    s, e = response_idx[0], response_idx[1]
    seq_len = torch.tensor(e - s + 1)

    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
    final_answer = extract_boxed_solution(decoded_final_answer)

    # An answer is considered correct only when the format is valid (a boxed
    # answer was found) and it matches the ground truth; everything else,
    # including a missing \boxed{}, is rewarded as wrong
    if final_answer is not None:
        format_acc += 1
        if gt_answer.strip().lower() == final_answer.strip().lower():
            ans_acc += 1
            reward = reward_correct
        else:
            reward = reward_wrong
    else:
        reward = reward_wrong
    return torch.tensor([reward, format_acc, ans_acc, seq_len]).to(input_ids.device)
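
To make the group-relative length bonus concrete, here is a small numeric sketch of the masked min/max arithmetic that AGPOReward applies per prompt group; the lengths and accuracies are made up for illustration.

import torch

# One hypothetical group of 4 generations: three correct, one wrong
seq_len = torch.tensor([120.0, 80.0, 200.0, 50.0])
ans_acc = torch.tensor([1.0, 1.0, 1.0, 0.0])

incorrect = ans_acc == 0
lo = seq_len.masked_fill(incorrect, float("inf")).min()   # 80: shortest correct
hi = seq_len.masked_fill(incorrect, float("-inf")).max()  # 200: longest correct
bonus = 0.1 * (1 - (seq_len - lo) / (hi - lo + 1e-6))
bonus = bonus.masked_fill(incorrect, 0.0)
print(bonus)  # ~[0.0667, 0.1000, 0.0000, 0.0000]

The shortest correct answer receives the full 0.1 bonus, the longest correct one receives none, and incorrect answers (even the very short 50-token one) are excluded both from the min/max statistics and from the bonus itself.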