Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-07-30 15:00:39 +00:00)
update agpo reward

commit 13c2676612
parent a5380d7073
@@ -20,9 +20,6 @@ class AGPOReward:
     ) -> torch.Tensor:
         # Get batch size
         bs = input_ids.size(0)
-        # Initialize reward
-        rewards = torch.zeros((bs, 3), device=input_ids.device)
-        len_rewards = torch.zeros((bs), device=input_ids.device)
         num_generations = self.kwargs.get("num_generations")

         # Apply the reward function to the entire batch at once
@@ -32,26 +29,24 @@ class AGPOReward:
         ans_acc_batch = torch.stack([info[2] for info in reward_infos])
         seq_len_batch = torch.stack([info[3] for info in reward_infos])

         # calculate mask
         group_reward = reward_batch.view(-1, num_generations)
         reward_std = group_reward.std(dim=1).repeat_interleave(num_generations, dim=0)
-        mask_zero_std = reward_std == 0
+        mask = (reward_std == 0) | (ans_acc_batch == 0)

         # process group seq len
         group_seq_len = seq_len_batch.view(-1, num_generations)
         group_ans_acc = ans_acc_batch.view(-1, num_generations)
-        mask_incorrect = group_ans_acc == 0
-        masked_seq_len_for_max = group_seq_len.masked_fill(mask_incorrect, float('-inf'))
+        group_mask = mask.view(-1, num_generations)
+        masked_seq_len_for_max = group_seq_len.masked_fill(group_mask, -1e6)
         max_group_seq_len = masked_seq_len_for_max.max(dim=1).repeat_interleave(num_generations, dim=0)
-        masked_seq_len_for_min = group_seq_len.masked_fill(mask_incorrect, float('inf'))
+        masked_seq_len_for_min = group_seq_len.masked_fill(group_mask, 1e6)
         min_group_seq_len = masked_seq_len_for_min.min(dim=1).repeat_interleave(num_generations, dim=0)

         # correct sample length reward
         len_ratio = (seq_len_batch - min_group_seq_len) / (max_group_seq_len - min_group_seq_len + 1e-6)
         len_rewards = 0.1 * (1 - len_ratio)

-        len_rewards = len_rewards.masked_fill(mask_zero_std, 0.0)
-        len_rewards = len_rewards.masked_fill(mask_incorrect, 0.0)
+        len_rewards = len_rewards.masked_fill(mask, 0.0)
         reward_batch += len_rewards

-        rewards += torch.stack(
-            [torch.tensor([r, f, a]).to(input_ids.device) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)],
-            dim=0
-        )
+        rewards = torch.stack([torch.stack([r, f, a]) for r, f, a in zip(reward_batch, format_acc_batch, ans_acc_batch)])
         return rewards
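
For reference, below is a minimal, self-contained sketch of the group-wise length reward as it reads after this change. It is an illustration, not the repository code: the helper name length_reward_sketch, the example tensors, and the explicit .values on the max/min results are assumptions added to keep the snippet runnable; the variable names and the 0.1 coefficient follow the diff above.

import torch

def length_reward_sketch(reward_batch, ans_acc_batch, seq_len_batch, num_generations):
    # Single mask replacing mask_zero_std / mask_incorrect: a sample gets no length
    # reward if its group has zero reward std or its own answer is wrong.
    group_reward = reward_batch.view(-1, num_generations)
    reward_std = group_reward.std(dim=1).repeat_interleave(num_generations)
    mask = (reward_std == 0) | (ans_acc_batch == 0)

    group_seq_len = seq_len_batch.view(-1, num_generations)
    group_mask = mask.view(-1, num_generations)
    # Finite sentinels (-1e6 / 1e6) keep the per-group max/min finite even when a whole group is masked.
    max_len = group_seq_len.masked_fill(group_mask, -1e6).max(dim=1).values.repeat_interleave(num_generations)
    min_len = group_seq_len.masked_fill(group_mask, 1e6).min(dim=1).values.repeat_interleave(num_generations)

    # Shorter correct answers earn a larger bonus, capped at 0.1.
    len_ratio = (seq_len_batch - min_len) / (max_len - min_len + 1e-6)
    len_rewards = 0.1 * (1 - len_ratio)
    return len_rewards.masked_fill(mask, 0.0)

# Hypothetical usage: 2 prompts x 4 generations.
reward_batch = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
ans_acc_batch = torch.tensor([1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
seq_len_batch = torch.tensor([120.0, 80.0, 200.0, 150.0, 90.0, 90.0, 90.0, 90.0])
print(length_reward_sketch(reward_batch, ans_acc_batch, seq_len_batch, num_generations=4))

One plausible reason for the finite -1e6 / 1e6 sentinels over the removed float('-inf') / float('inf') fills is that they keep max_group_seq_len, min_group_seq_len, and hence len_ratio finite even when an entire group is masked; either way, those entries are zeroed by the final masked_fill, and folding the two old masks into a single mask keeps it consistent with the group_mask used for the per-group max/min.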