From 13c2676612ba1d96bb8b419889ebd5572901b9e5 Mon Sep 17 00:00:00 2001
From: Chen Li
Date: Wed, 14 May 2025 16:54:52 +0800
Subject: [PATCH] update agpo reward

---
 .../coati/distributed/reward/agpo_reward.py | 25 ++++++++++---------------
 1 file changed, 10 insertions(+), 15 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/reward/agpo_reward.py b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
index 648de9e3e..ae2b01756 100644
--- a/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
+++ b/applications/ColossalChat/coati/distributed/reward/agpo_reward.py
@@ -20,9 +20,6 @@ class AGPOReward:
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
-        # Initialize reward
-        rewards = torch.zeros((bs, 3), device=input_ids.device)
-        len_rewards = torch.zeros((bs), device=input_ids.device)
         num_generations = self.kwargs.get("num_generations")
 
         # Apply the reward function to the entire batch at once
@@ -32,26 +29,24 @@ class AGPOReward:
         ans_acc_batch = torch.stack([info[2] for info in reward_infos])
         seq_len_batch = torch.stack([info[3] for info in reward_infos])
 
+        # calculate mask
         group_reward = reward_batch.view(-1, num_generations)
         reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
-        mask_zero_std = reward_std == 0
+        mask = (reward_std == 0) | (ans_acc_batch == 0)
 
+        # process group seq len
         group_seq_len = seq_len_batch.view(-1, num_generations)
-        group_ans_acc = ans_acc_batch.view(-1, num_generations)
-        mask_incorrect = group_ans_acc == 0
-        masked_seq_len_for_max = group_seq_len.masked_fill(mask_incorrect, float('-inf'))
+        group_mask = mask.view(-1, num_generations)
+        masked_seq_len_for_max = group_seq_len.masked_fill(group_mask, -1e6)
         max_group_seq_len = masked_seq_len_for_max.max(dim=1).repeat_interleave(num_generations, dim=0)
-        masked_seq_len_for_min = group_seq_len.masked_fill(mask_incorrect, float('inf'))
+        masked_seq_len_for_min = group_seq_len.masked_fill(group_mask, 1e6)
         min_group_seq_len = masked_seq_len_for_min.min(dim=1).repeat_interleave(num_generations, dim=0)
+
+        # correct sample length reward
         len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
         len_rewards = 0.1 * (1 - len_ratio)
-
-        len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)
-        len_rewards = len_rewards.masked_fill(mask_incorrect, 0.0)
+        len_rewards = len_rewards.masked_fill(mask, 0.0)
 
         reward_batch += len_rewards
-        rewards += torch.stack(
-            [torch.tensor([r, f, a]).to(input_ids.device) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)],
-            dim=0
-        )
+        rewards = torch.stack([torch.stack([r, f, a]) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)])
         return rewards
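
Standalone sketch (not part of the patch): a minimal, runnable illustration of the new
masking and length-reward computation on toy tensors, for review convenience. The shapes,
toy values, and the `.values` accessor after `max(dim=1)` / `min(dim=1)` are assumptions
made so the snippet is self-contained; it is not code taken from the repository.

import torch

num_generations = 4
# Toy batch: 2 groups x 4 generations, flat tensors of shape (bs,).
reward_batch = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.5, 0.5])
format_acc_batch = torch.ones(8)
ans_acc_batch = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0])
seq_len_batch = torch.tensor([120.0, 80.0, 200.0, 100.0, 60.0, 60.0, 60.0, 60.0])

# Mask samples whose group has zero reward std or whose answer is incorrect.
group_reward = reward_batch.view(-1, num_generations)
reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
mask = (reward_std == 0) | (ans_acc_batch == 0)

# Per-group max/min sequence length over the unmasked samples.
group_seq_len = seq_len_batch.view(-1, num_generations)
group_mask = mask.view(-1, num_generations)
max_group_seq_len = group_seq_len.masked_fill(group_mask, -1e6).max(dim=1).values.repeat_interleave(num_generations, dim=0)
min_group_seq_len = group_seq_len.masked_fill(group_mask, 1e6).min(dim=1).values.repeat_interleave(num_generations, dim=0)

# Shorter correct samples earn a larger length bonus (up to 0.1); masked samples earn none.
len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
len_rewards = (0.1 * (1 - len_ratio)).masked_fill(mask, 0.0)

reward_batch = reward_batch + len_rewards
rewards = torch.stack([torch.stack([r, f, a]) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)])
print(rewards)  # shape (bs, 3): [reward + length bonus, format_acc, ans_acc]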