Mirror of https://github.com/hpcaitech/ColossalAI.git (synced 2025-06-27 07:47:05 +00:00)
update grpo
This commit is contained in:
parent eb6337f07f
commit 9d9d51614e
@@ -56,6 +56,9 @@ class GRPOConsumer(BaseConsumer):
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
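Note on the new buffers: they follow the same pattern as the existing accum_loss / accum_reward / accum_kl metrics, i.e. a one-element tensor kept on the training device and updated in place between optimizer updates. A minimal sketch of that lifecycle (standalone illustration, with a made-up device variable and stand-in values, not project code):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

accum_format_reward = torch.zeros(1, device=device)  # one-element buffer on the training device
accum_count = 0

for step_metric in [0.25, 0.50, 0.75]:  # stand-in per-step format rewards
    accum_format_reward.add_(torch.tensor(step_metric, device=device))  # in-place add, no host round-trip
    accum_count += 1

print("running mean:", accum_format_reward.item() / accum_count)  # 0.5
accum_format_reward.zero_()  # reset after logging, as the consumer does on update
accum_count = 0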
@@ -131,9 +134,14 @@ class GRPOConsumer(BaseConsumer):
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)

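The rename from reward to reward_group reflects that the reward model now yields one entry per generated sample, and the indexing suggests each entry is a triple of (total reward, format reward, accuracy reward); the exact return type is not shown in this diff, so that shape is an assumption. A quick sketch of the column-wise unpacking and the group reshape with dummy values:

import torch

# Assumed shape: one (total, format, acc) triple per generated sample.
reward_group = [(1.0, 0.5, 0.5), (0.0, 0.0, 0.0), (0.5, 0.5, 0.0), (1.0, 0.5, 0.5)]
device = "cpu"          # stands in for data["input_ids"].device
num_generations = 2     # hypothetical group size

reward = torch.tensor([value[0] for value in reward_group]).to(device)
format_reward = torch.tensor([value[1] for value in reward_group]).to(device)
acc_reward = torch.tensor([value[2] for value in reward_group]).to(device)

# [batch_size, num_generations]: each row holds the rewards of one prompt's generations.
group_reward = reward.view(-1, num_generations)
print(group_reward)  # tensor([[1.0000, 0.0000], [0.5000, 1.0000]])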
@@ -157,9 +165,16 @@ class GRPOConsumer(BaseConsumer):
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
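all_reduce_mean(..., self.plugin) is a helper from the surrounding codebase; this hunk only shows it being applied to the new metrics so every worker accumulates the same value. Conceptually it averages a scalar across data-parallel ranks. A rough, plugin-free sketch of that idea using plain torch.distributed (an approximation of the behavior, not the project's implementation):

import torch
import torch.distributed as dist

def all_reduce_mean_sketch(tensor: torch.Tensor) -> torch.Tensor:
    """Average a tensor across all data-parallel ranks (illustrative only)."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # sum the local values in place
        tensor /= dist.get_world_size()                # divide by rank count to get the mean
    return tensor

# Usage mirrors the diff: reduce the per-batch mean before accumulating it.
# format_reward = all_reduce_mean_sketch(format_reward.mean())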
@@ -173,17 +188,28 @@ class GRPOConsumer(BaseConsumer):
                     self.accum_reward.item() / self.accum_count,
                     "KL:",
                     self.accum_kl.item() / self.accum_count,
+                    "Format Reward:",
+                    self.accum_format_reward.item() / self.accum_count,
+                    "Acc Reward:",
+                    self.accum_acc_reward.item() / self.accum_count,
+                    "Advantages:",
+                    self.accum_advantages.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
                         "train/kl": self.accum_kl.item() / self.accum_count,
+                        "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                        "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/advantages": self.accum_advantages.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
             self.accum_kl.zero_()
+            self.accum_acc_reward.zero_()
+            self.accum_format_reward.zero_()
             self.accum_count = 0
             return loss_scalar

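Taken together, the new metrics follow the same cycle as loss, reward, and KL: accumulate per micro-step, average over accum_count at the optimizer update, report via print and wandb_run.log, then reset. A self-contained, single-process sketch of that cycle (made-up numbers and a plain print standing in for wandb_run.log):

import torch

device = "cpu"
names = ("loss", "reward", "kl", "format_reward", "acc_reward", "advantages")
accum = {name: torch.zeros(1, device=device) for name in names}
accum_count = 0

fake_steps = [
    {"loss": 0.9, "reward": 0.2, "kl": 0.01, "format_reward": 0.5, "acc_reward": 0.1, "advantages": 0.0},
    {"loss": 0.7, "reward": 0.4, "kl": 0.02, "format_reward": 0.7, "acc_reward": 0.3, "advantages": 0.1},
]

for step in fake_steps:                  # one entry per micro-step between optimizer updates
    for name, value in step.items():
        accum[name].add_(value)          # same in-place accumulation as the consumer
    accum_count += 1

# On update: report running means, then reset, mirroring the print/wandb block above.
report = {f"train/{name}": buf.item() / accum_count for name, buf in accum.items()}
print(report)
for buf in accum.values():
    buf.zero_()
accum_count = 0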