mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-06-26 07:22:12 +00:00

update grpo

This commit is contained in:
parent eb6337f07f
commit 9d9d51614e
@@ -56,6 +56,9 @@ class GRPOConsumer(BaseConsumer):
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
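The new buffers (accum_format_reward, accum_acc_reward, accum_advantages) follow the same pattern as the existing accum_loss / accum_reward / accum_kl: a scalar tensor on the consumer's device that sums per-step values until accum_count is used to average and reset them. A minimal standalone sketch of that accumulate-then-average pattern; the MetricAccumulator name and API below are illustrative, not part of the repo:

```python
import torch

class MetricAccumulator:
    """Running-sum accumulator for scalar training metrics (illustrative sketch)."""

    def __init__(self, device: str = "cpu"):
        self.total = torch.zeros(1, device=device)
        self.count = 0

    def add(self, value: torch.Tensor) -> None:
        # Detach so the accumulator never keeps an autograd graph alive.
        self.total.add_(value.detach())
        self.count += 1

    def average_and_reset(self) -> float:
        avg = (self.total / max(self.count, 1)).item()
        self.total.zero_()
        self.count = 0
        return avg

# Usage, mirroring one of the new buffers such as accum_format_reward:
acc = MetricAccumulator()
for step_value in (torch.tensor([0.5]), torch.tensor([1.0])):
    acc.add(step_value)
print(acc.average_and_reset())  # 0.75
```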
@@ -131,9 +134,14 @@ class GRPOConsumer(BaseConsumer):
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )

+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)

             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
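The diff suggests the reward model now returns, per sample, a triple of (total reward, format reward, accuracy reward); the three list comprehensions unpack those columns into separate tensors, and the total reward is then viewed as [batch_size, num_generations] so each prompt's group of completions can be normalized together, as in GRPO's group-relative baseline. A rough sketch under that assumption, with a hypothetical stand-in reward model and an assumed group-normalized advantage (the actual advantage computation is not shown in this diff):

```python
import torch

def fake_reward_model(input_ids: torch.Tensor) -> list:
    # Hypothetical stand-in: one (total, format, accuracy) triple per sequence.
    return [(1.0, 0.5, 0.5) for _ in range(input_ids.size(0))]

num_generations = 4
input_ids = torch.zeros(8, 16, dtype=torch.long)  # 2 prompts x 4 generations

reward_group = fake_reward_model(input_ids)
reward = torch.tensor([v[0] for v in reward_group])
format_reward = torch.tensor([v[1] for v in reward_group])
acc_reward = torch.tensor([v[2] for v in reward_group])

# Group completions of the same prompt together: [batch_size, num_generations].
group_reward = reward.view(-1, num_generations)

# GRPO-style group-relative advantage (sketch): subtract the group mean and
# divide by the group std, then flatten back to one advantage per sample.
advantages = (group_reward - group_reward.mean(dim=1, keepdim=True)) / (
    group_reward.std(dim=1, keepdim=True) + 1e-6
)
advantages = advantages.reshape(-1)
```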
@@ -157,9 +165,16 @@ class GRPOConsumer(BaseConsumer):
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
             # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
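Each scalar is first reduced to a per-rank mean and then averaged across data-parallel ranks with all_reduce_mean before being accumulated, so the buffers hold globally consistent values rather than per-rank ones. A bare torch.distributed equivalent of that helper (a sketch; the repo's version is plugin-aware and takes the booster plugin as an argument):

```python
import torch
import torch.distributed as dist

def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all ranks; a no-op when not running distributed."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
        tensor.div_(dist.get_world_size())
    return tensor

# Typical call site, mirroring the diff: take the per-rank mean first,
# then average that scalar across ranks before accumulating it.
# format_reward = all_reduce_mean(format_reward.mean())
```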
@@ -173,17 +188,28 @@ class GRPOConsumer(BaseConsumer):
                     self.accum_reward.item() / self.accum_count,
                     "KL:",
                     self.accum_kl.item() / self.accum_count,
+                    "Format Reward:",
+                    self.accum_format_reward.item() / self.accum_count,
+                    "Acc Reward:",
+                    self.accum_acc_reward.item() / self.accum_count,
+                    "Advantages:",
+                    self.accum_advantages.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
+                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_acc_reward.zero_()
+            self.accum_format_reward.zero_()
             self.accum_count = 0
             return loss_scalar
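When an optimizer step is taken, the accumulated sums are divided by accum_count to get per-update averages, printed and logged to Weights & Biases, and then reset. (As shown, this hunk zeroes accum_format_reward and accum_acc_reward but not accum_advantages; whether that buffer is reset elsewhere is not visible in the diff.) A condensed sketch of that log-then-reset step, using a hypothetical metrics dict in place of the individual buffers:

```python
import torch

def log_and_reset(accumulators: dict, count: int, wandb_run=None) -> float:
    """Average accumulated metrics over `count` steps, log them, then reset (sketch)."""
    averages = {name: buf.item() / max(count, 1) for name, buf in accumulators.items()}
    print(" ".join(f"{name}: {value:.4f}" for name, value in averages.items()))
    if wandb_run is not None:
        wandb_run.log({f"train/{name}": value for name, value in averages.items()})
    for buf in accumulators.values():
        buf.zero_()
    return averages["loss"]

# Example with the metrics tracked in this commit.
metrics = {
    "loss": torch.tensor([0.2]),
    "reward": torch.tensor([1.0]),
    "kl": torch.tensor([0.01]),
    "format_reward": torch.tensor([0.5]),
    "acc_reward": torch.tensor([0.5]),
    "advantages": torch.tensor([0.0]),
}
loss_scalar = log_and_reset(metrics, count=1)
```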