update grpo

Tong Li 2025-03-10 14:12:04 +08:00
parent eb6337f07f
commit 9d9d51614e


@@ -56,6 +56,9 @@ class GRPOConsumer(BaseConsumer):
self.accum_loss = torch.zeros(1, device=self.device)
self.accum_reward = torch.zeros(1, device=self.device)
self.accum_kl = torch.zeros(1, device=self.device)
+ self.accum_format_reward = torch.zeros(1, device=self.device)
+ self.accum_acc_reward = torch.zeros(1, device=self.device)
+ self.accum_advantages = torch.zeros(1, device=self.device)
self.accum_count = 0
# Reference model is initialized from policy model.
@@ -131,9 +134,14 @@ class GRPOConsumer(BaseConsumer):
)
kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
- reward = self.reward_model(
+ reward_group = self.reward_model(
data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
)
+ reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+ format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+ acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
# [batch_size, num_generations]
group_reward = reward.view(-1, self.num_generations)
@@ -157,9 +165,16 @@ class GRPOConsumer(BaseConsumer):
loss = all_reduce_mean(loss, self.plugin)
reward = all_reduce_mean(reward.mean(), self.plugin)
kl = all_reduce_mean(kl.mean(), self.plugin)
+ format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+ acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+ advantages = all_reduce_mean(advantages.mean(), self.plugin)
+ # Calculate accumulate value.
self.accum_loss.add_(loss.data)
self.accum_reward.add_(reward.data)
self.accum_kl.add_(kl.data)
+ self.accum_format_reward.add_(format_reward.data)
+ self.accum_acc_reward.add_(acc_reward.data)
+ self.accum_advantages.add_(advantages.data)
self.accum_count += 1
if need_update:
self.optimizer.step()
@@ -173,17 +188,28 @@ class GRPOConsumer(BaseConsumer):
self.accum_reward.item() / self.accum_count,
"KL:",
self.accum_kl.item() / self.accum_count,
"Format Reward:",
self.accum_format_reward.item() / self.accum_count,
"Acc Reward:",
self.accum_acc_reward.item() / self.accum_count,
"Advantages:",
self.accum_advantages.item() / self.accum_count,
)
self.wandb_run.log(
{
"train/loss": self.accum_loss.item() / self.accum_count,
"train/reward": self.accum_reward.item() / self.accum_count,
"train/kl": self.accum_kl.item() / self.accum_count,
"train/format_reward": self.accum_format_reward.item() / self.accum_count,
"train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
"train/advantages": self.accum_advantages.item() / self.accum_count,
}
)
self.accum_loss.zero_()
self.accum_reward.zero_()
self.accum_kl.zero_()
+ self.accum_acc_reward.zero_()
+ self.accum_format_reward.zero_()
self.accum_count = 0
return loss_scalar
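
The reward model's internals are not part of this diff; the change only assumes it returns one (overall, format, accuracy) reward triple per generated sequence. Below is a minimal sketch of how such a triple-valued reward can be unpacked into per-component tensors and regrouped into a [num_prompts, num_generations] matrix, the shape over which GRPO's group-relative advantage is computed. All names here (toy_reward_fn, the num_generations value) are illustrative, not taken from the repository.

import torch

# Hypothetical stand-in for self.reward_model: one (total, format, acc) triple per response.
def toy_reward_fn(batch_size):
    fmt = torch.randint(0, 2, (batch_size,)).float()   # 1.0 if the response is well formatted
    acc = torch.randint(0, 2, (batch_size,)).float()   # 1.0 if the answer matches gt_answer
    total = fmt + acc                                   # assumed combination; the real weighting may differ
    return [(t.item(), f.item(), a.item()) for t, f, a in zip(total, fmt, acc)]

num_generations = 4                          # responses sampled per prompt
reward_group = toy_reward_fn(2 * num_generations)

reward = torch.tensor([v[0] for v in reward_group])
format_reward = torch.tensor([v[1] for v in reward_group])
acc_reward = torch.tensor([v[2] for v in reward_group])

# Responses of the same prompt land in one row: [num_prompts, num_generations].
group_reward = reward.view(-1, num_generations)
print(group_reward.shape)                    # torch.Size([2, 4])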
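
On the logging side, the new metrics follow the accumulate-then-average pattern already used for loss, reward, and KL: each micro-batch adds its (all-reduced) scalar into a persistent buffer, and at the optimizer step the buffers are divided by accum_count, printed and sent to wandb, then zeroed for the next window. A minimal single-process sketch of that pattern, with wandb and the distributed all_reduce_mean omitted and all names illustrative:

import torch

class MetricAccumulator:
    """Accumulate per-micro-batch scalars and report their mean at each optimizer update."""

    def __init__(self, names, device="cpu"):
        self.buffers = {name: torch.zeros(1, device=device) for name in names}
        self.count = 0

    def add(self, **scalars):
        for name, value in scalars.items():
            self.buffers[name].add_(float(value))
        self.count += 1

    def pop_means(self):
        means = {name: buf.item() / self.count for name, buf in self.buffers.items()}
        for buf in self.buffers.values():
            buf.zero_()          # reset every buffer so the next window starts clean
        self.count = 0
        return means

metrics = MetricAccumulator(["loss", "reward", "kl", "format_reward", "acc_reward", "advantages"])
for _ in range(4):               # pretend four micro-batches per optimizer update
    metrics.add(loss=0.5, reward=1.0, kl=0.01, format_reward=1.0, acc_reward=0.5, advantages=0.1)
print(metrics.pop_means())       # {'loss': 0.5, 'reward': 1.0, 'kl': 0.01, ...}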