mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-08-10 04:18:05 +00:00
agpo reward
This commit is contained in:
parent
e08626d740
commit
a5380d7073
@ -0,0 +1,57 @@
|
||||
"""
|
||||
Function-based reward verification module.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AGPOReward:
|
||||
def __init__(self, reward_fn: callable, **kwargs: List[Dict[str, Any]]):
|
||||
self.reward_fn = reward_fn
|
||||
self.kwargs = kwargs
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
gt_answer: List[torch.Tensor] = None,
|
||||
response_idx: List[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get batch size
|
||||
bs = input_ids.size(0)
|
||||
# Initialize reward
|
||||
rewards = torch.zeros((bs, 3), device=input_ids.device)
|
||||
len_rewards = torch.zeros((bs), device=input_ids.device)
|
||||
num_generations = self.kwargs.get("num_generations")
|
||||
|
||||
# Apply the reward function to the entire batch at once
|
||||
reward_infos = [self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs) for i in range(bs)]
|
||||
reward_batch = torch.stack([info[0] for info in reward_infos])
|
||||
format_acc_batch = torch.stack([info[1] for info in reward_infos])
|
||||
ans_acc_batch = torch.stack([info[2] for info in reward_infos])
|
||||
seq_len_batch = torch.stack([info[3] for info in reward_infos])
|
||||
|
||||
group_reward = reward_batch.view(-1, num_generations)
|
||||
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
|
||||
mask_zero_std = reward_std == 0
|
||||
|
||||
group_seq_len = seq_len_batch.view(-1, num_generations)
|
||||
group_ans_acc = ans_acc_batch.view(-1, num_generations)
|
||||
mask_incorrect = group_ans_acc == 0
|
||||
masked_seq_len_for_max = group_seq_len.masked_fill(mask_incorrect, float('-inf'))
|
||||
max_group_seq_len = masked_seq_len_for_max.max(dim=1).repeat_interleave(num_generations, dim=0)
|
||||
masked_seq_len_for_min = group_seq_len.masked_fill(mask_incorrect, float('inf'))
|
||||
min_group_seq_len = masked_seq_len_for_min.min(dim=1).repeat_interleave(num_generations, dim=0)
|
||||
len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
|
||||
len_rewards = 0.1 * (1 - len_ratio)
|
||||
|
||||
len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)
|
||||
len_rewards = len_rewards.masked_fill(mask_incorrect, 0.0)
|
||||
reward_batch += len_rewards
|
||||
|
||||
rewards += torch.stack(
|
||||
[torch.tensor([r, f, a]).to(input_ids.device) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)],
|
||||
dim=0
|
||||
)
|
||||
return rewards
|
@ -110,3 +110,28 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
reward = reward + length_reward
|
||||
|
||||
return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
|
||||
|
||||
|
||||
def agpo_boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
|
||||
tokenizer = kwargs["tokenizer"]
|
||||
reward_correct = torch.tensor(1.0)
|
||||
reward_wrong = torch.tensor(-1.0)
|
||||
ans_acc = torch.tensor(0.0)
|
||||
format_acc = torch.tensor(0.0)
|
||||
s, e = response_idx[0], response_idx[1]
|
||||
seq_len = torch.tensor(e - s + 1)
|
||||
|
||||
decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
|
||||
gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
|
||||
final_answer = extract_boxed_solution(decoded_final_answer)
|
||||
|
||||
# Check answer accuracy, answer is considered correct if the answer is correct and the format is valid
|
||||
if final_answer is not None:
|
||||
format_acc += 1
|
||||
if gt_answer.strip().lower() == final_answer.strip().lower():
|
||||
ans_acc += 1
|
||||
reward = reward_correct
|
||||
else:
|
||||
reward = reward_wrong
|
||||
|
||||
return torch.tensor([reward, format_acc, ans_acc, seq_len]).to(input_ids.device)
|
||||
|
Loading…
Reference in New Issue
Block a user