diff --git a/applications/ColossalChat/coati/distributed/reward/agpo_reward.py b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
new file mode 100644
index 000000000..648de9e3e
--- /dev/null
+++ b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
@@ -0,0 +1,63 @@
+"""
+Function-based reward verification module.
+"""
+
+from typing import Any, List
+
+import torch
+
+
+class AGPOReward:
+    def __init__(self, reward_fn: callable, **kwargs: Any):
+        self.reward_fn = reward_fn
+        self.kwargs = kwargs
+
+    def __call__(
+        self,
+        input_ids: torch.LongTensor,
+        gt_answer: List[torch.Tensor] = None,
+        response_idx: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        # Get batch size
+        bs = input_ids.size(0)
+        # Initialize rewards: one (reward, format_acc, ans_acc) triple per sample
+        rewards = torch.zeros((bs, 3), device=input_ids.device)
+        len_rewards = torch.zeros((bs,), device=input_ids.device)
+        num_generations = self.kwargs.get("num_generations")
+
+        # Apply the reward function to each sample in the batch
+        reward_infos = [
+            self.reward_fn(input_ids[i], gt_answer=gt_answer[i], response_idx=response_idx[i], **self.kwargs)
+            for i in range(bs)
+        ]
+        reward_batch = torch.stack([info[0] for info in reward_infos])
+        format_acc_batch = torch.stack([info[1] for info in reward_infos])
+        ans_acc_batch = torch.stack([info[2] for info in reward_infos])
+        seq_len_batch = torch.stack([info[3] for info in reward_infos])
+
+        # Group responses by prompt: every num_generations consecutive samples form one group
+        group_reward = reward_batch.view(-1, num_generations)
+        reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
+        mask_zero_std = reward_std == 0
+
+        # Length reward: among the correct responses in a group, shorter ones score higher
+        group_seq_len = seq_len_batch.view(-1, num_generations)
+        group_ans_acc = ans_acc_batch.view(-1, num_generations)
+        mask_incorrect = group_ans_acc == 0
+        masked_seq_len_for_max = group_seq_len.masked_fill(mask_incorrect, float("-inf"))
+        max_group_seq_len = masked_seq_len_for_max.max(dim=1).values.repeat_interleave(num_generations, dim=0)
+        masked_seq_len_for_min = group_seq_len.masked_fill(mask_incorrect, float("inf"))
+        min_group_seq_len = masked_seq_len_for_min.min(dim=1).values.repeat_interleave(num_generations, dim=0)
+        len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
+        len_rewards = 0.1 * (1 - len_ratio)
+
+        # No length bonus for degenerate groups (zero reward std) or for incorrect answers
+        len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)
+        len_rewards = len_rewards.masked_fill(mask_incorrect.view(-1), 0.0)
+        reward_batch += len_rewards
+
+        rewards += torch.stack(
+            [torch.tensor([r, f, a]).to(input_ids.device) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)],
+            dim=0,
+        )
+        return rewards
diff --git a/applications/ColossalChat/coati/distributed/reward/reward_fn.py b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
index b68c1a92f..198913b49 100644
--- a/applications/ColossalChat/coati/distributed/reward/reward_fn.py
+++ b/applications/ColossalChat/coati/distributed/reward/reward_fn.py
@@ -110,3 +110,28 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
         reward = reward + length_reward
 
     return torch.tensor([reward, format_acc, ans_acc]).to(input_ids.device)
+
+
+def agpo_boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
+    tokenizer = kwargs["tokenizer"]
+    reward_correct = torch.tensor(1.0)
+    reward_wrong = torch.tensor(-1.0)
+    ans_acc = torch.tensor(0.0)
+    format_acc = torch.tensor(0.0)
+    s, e = response_idx[0], response_idx[1]
+    seq_len = torch.tensor(e - s + 1)
+
+    decoded_final_answer = tokenizer.decode(input_ids[s : e + 1], skip_special_tokens=True)
+    gt_answer = tokenizer.decode(gt_answer.squeeze(0), skip_special_tokens=True)
+    final_answer = extract_boxed_solution(decoded_final_answer)
+
+    # Default to the wrong-answer reward so every code path assigns one; the answer
+    # only counts as correct if the format is valid and it matches the ground truth
+    reward = reward_wrong
+    if final_answer is not None:
+        format_acc += 1
+        if gt_answer.strip().lower() == final_answer.strip().lower():
+            ans_acc += 1
+            reward = reward_correct
+
+    return torch.tensor([reward, format_acc, ans_acc, seq_len]).to(input_ids.device)
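
For reviewers: a self-contained sketch of the length-bonus arithmetic in `AGPOReward.__call__`, using made-up group sizes and sequence lengths (none of the values below are from this PR). It also shows why `.values` is needed after `max(dim=1)`/`min(dim=1)`, and why `mask_incorrect` must be flattened before masking the per-sample `len_rewards`.

```python
import torch

# Hypothetical rollout: 2 prompts x 4 generations each (num_generations=4)
num_generations = 4
ans_acc = torch.tensor([1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0])  # per-sample correctness
seq_len = torch.tensor([120.0, 300.0, 80.0, 200.0, 50.0, 60.0, 70.0, 80.0])
reward = torch.where(ans_acc > 0, torch.tensor(1.0), torch.tensor(-1.0))

# Per-group reward std; the second group is all-wrong, so its std is 0
reward_std = reward.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations)
mask_zero_std = reward_std == 0

group_seq_len = seq_len.view(-1, num_generations)
mask_incorrect = ans_acc.view(-1, num_generations) == 0

# Only correct responses compete for the bonus; max/min with dim= return (values, indices)
max_len = group_seq_len.masked_fill(mask_incorrect, float("-inf")).max(dim=1).values.repeat_interleave(num_generations)
min_len = group_seq_len.masked_fill(mask_incorrect, float("inf")).min(dim=1).values.repeat_interleave(num_generations)

len_ratio = (seq_len - min_len) / (max_len - min_len + 1e-6)
len_rewards = 0.1 * (1 - len_ratio)
len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)          # shape (8,) mask
len_rewards = len_rewards.masked_fill(mask_incorrect.view(-1), 0.0)  # flatten (2, 4) -> (8,)

print(len_rewards)
# approximately tensor([0.0667, 0.0000, 0.1000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
```

The shortest correct answer in a group earns the full +0.1, the longest correct one earns roughly 0, and incorrect samples or all-wrong groups get no bonus at all.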
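And a hypothetical end-to-end wiring of the two new pieces. The tokenizer checkpoint, toy responses, and `num_generations=2` are illustrative placeholders, not values taken from this PR:

```python
import torch
from transformers import AutoTokenizer

from coati.distributed.reward.agpo_reward import AGPOReward
from coati.distributed.reward.reward_fn import agpo_boxed_math_reward_fn

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-7B-Instruct")  # placeholder checkpoint

# Two sampled responses for the same prompt, i.e. one group of num_generations=2
texts = ["The answer is \\boxed{4}.", "So we get \\boxed{5}."]
encoded = [torch.tensor(tokenizer.encode(t)) for t in texts]
input_ids = torch.nn.utils.rnn.pad_sequence(
    encoded, batch_first=True, padding_value=tokenizer.eos_token_id
)
# Treat each full sequence as the "response" span [start, end]
response_idx = [torch.tensor([0, len(e) - 1]) for e in encoded]
# Ground-truth answer tokens, one (1, n) tensor per sample, matching the .squeeze(0) in the reward fn
gt_answer = [torch.tensor(tokenizer.encode("4")).unsqueeze(0)] * 2

reward_model = AGPOReward(agpo_boxed_math_reward_fn, tokenizer=tokenizer, num_generations=2)
rewards = reward_model(input_ids, gt_answer=gt_answer, response_idx=response_idx)
print(rewards)  # (bs, 3): [reward incl. length bonus, format_acc, ans_acc] per sample
```

Here the first sample should score the correct-answer reward plus the length bonus, and the second the wrong-answer reward, with `format_acc` set for both since both contain a parsable `\boxed{}`.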