diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 15f7e340e..ae9f2c400 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -56,6 +56,9 @@ class GRPOConsumer(BaseConsumer):
         self.accum_loss = torch.zeros(1, device=self.device)
         self.accum_reward = torch.zeros(1, device=self.device)
         self.accum_kl = torch.zeros(1, device=self.device)
+        self.accum_format_reward = torch.zeros(1, device=self.device)
+        self.accum_acc_reward = torch.zeros(1, device=self.device)
+        self.accum_advantages = torch.zeros(1, device=self.device)
         self.accum_count = 0
 
         # Reference model is initialized from policy model.
@@ -131,9 +134,14 @@ class GRPOConsumer(BaseConsumer):
             )
             kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)
 
-            reward = self.reward_model(
+            reward_group = self.reward_model(
                 data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
             )
+
+            reward = torch.tensor([value[0] for value in reward_group]).to(data["input_ids"].device)
+            format_reward = torch.tensor([value[1] for value in reward_group]).to(data["input_ids"].device)
+            acc_reward = torch.tensor([value[2] for value in reward_group]).to(data["input_ids"].device)
+
             # [batch_size, num_generations]
             group_reward = reward.view(-1, self.num_generations)
 
@@ -157,9 +165,16 @@ class GRPOConsumer(BaseConsumer):
             loss = all_reduce_mean(loss, self.plugin)
             reward = all_reduce_mean(reward.mean(), self.plugin)
             kl = all_reduce_mean(kl.mean(), self.plugin)
+            format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
+            acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
+            advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
             self.accum_kl.add_(kl.data)
+            self.accum_format_reward.add_(format_reward.data)
+            self.accum_acc_reward.add_(acc_reward.data)
+            self.accum_advantages.add_(advantages.data)
             self.accum_count += 1
             if need_update:
                 self.optimizer.step()
@@ -173,17 +188,29 @@ class GRPOConsumer(BaseConsumer):
                         self.accum_reward.item() / self.accum_count,
                         "KL:",
                         self.accum_kl.item() / self.accum_count,
+                        "Format Reward:",
+                        self.accum_format_reward.item() / self.accum_count,
+                        "Acc Reward:",
+                        self.accum_acc_reward.item() / self.accum_count,
+                        "Advantages:",
+                        self.accum_advantages.item() / self.accum_count,
                     )
                     self.wandb_run.log(
                         {
                             "train/loss": self.accum_loss.item() / self.accum_count,
                             "train/reward": self.accum_reward.item() / self.accum_count,
                             "train/kl": self.accum_kl.item() / self.accum_count,
+                            "train/format_reward": self.accum_format_reward.item() / self.accum_count,
+                            "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                            "train/advantages": self.accum_advantages.item() / self.accum_count,
                         }
                     )
                 self.accum_loss.zero_()
                 self.accum_reward.zero_()
                 self.accum_kl.zero_()
+                self.accum_acc_reward.zero_()
+                self.accum_format_reward.zero_()
+                self.accum_advantages.zero_()
                 self.accum_count = 0
                 return loss_scalar
 
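
A minimal standalone sketch of the new reward unpacking and the group-relative advantage it feeds is given below. It is not the ColossalChat implementation: `fake_reward_model`, the batch sizes, and the mean/std normalization with a small epsilon are assumptions; the diff itself only shows the `(reward, format_reward, acc_reward)` unpacking and the `[batch_size, num_generations]` reshape.

```python
import torch

# Hypothetical stand-in for self.reward_model(...): one (total, format, acc) triple per sample.
def fake_reward_model(num_samples):
    return [(torch.rand(1).item(), 1.0, 0.0) for _ in range(num_samples)]

num_generations = 4  # responses sampled per prompt (assumption)
batch_size = 2       # prompts per step (assumption)
device = "cpu"

reward_group = fake_reward_model(batch_size * num_generations)

# Unpack the per-sample triples into flat tensors, as the diff does.
reward = torch.tensor([v[0] for v in reward_group], device=device)
format_reward = torch.tensor([v[1] for v in reward_group], device=device)
acc_reward = torch.tensor([v[2] for v in reward_group], device=device)

# GRPO-style group baseline: compare each reward against the other
# generations for the same prompt ([batch_size, num_generations] view).
group_reward = reward.view(-1, num_generations)
reward_mean = group_reward.mean(dim=1, keepdim=True)
reward_std = group_reward.std(dim=1, keepdim=True)
advantages = ((group_reward - reward_mean) / (reward_std + 1e-4)).flatten()

print(advantages.shape)  # torch.Size([8])
```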
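The bookkeeping added in the last two hunks follows a simple pattern: average each scalar across data-parallel ranks with `all_reduce_mean`, accumulate it into a per-metric buffer, and at update time report the mean over `accum_count` steps before zeroing every buffer. A single-process sketch of that pattern is shown below; the distributed reduction is omitted, and the class and names are illustrative, not part of coati.

```python
import torch

class MetricAccumulator:
    """Accumulate per-step scalar metrics and report their mean at update time.

    Single-process sketch: in the real consumer each value is first averaged
    across ranks (all_reduce_mean) before being accumulated.
    """

    def __init__(self, names, device="cpu"):
        self.buffers = {name: torch.zeros(1, device=device) for name in names}
        self.count = 0

    def add(self, **values):
        for name, value in values.items():
            self.buffers[name].add_(torch.as_tensor(value, dtype=torch.float32))
        self.count += 1

    def flush(self):
        # Report running means, then reset every buffer -- including
        # advantages -- so the next logging window starts from zero.
        means = {name: buf.item() / self.count for name, buf in self.buffers.items()}
        for buf in self.buffers.values():
            buf.zero_()
        self.count = 0
        return means


acc = MetricAccumulator(["loss", "reward", "kl", "format_reward", "acc_reward", "advantages"])
acc.add(loss=0.5, reward=1.0, kl=0.01, format_reward=1.0, acc_reward=0.0, advantages=0.2)
acc.add(loss=0.4, reward=0.0, kl=0.02, format_reward=0.0, acc_reward=0.0, advantages=-0.2)
print(acc.flush())  # e.g. {'loss': 0.45, 'reward': 0.5, ...}
```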
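For context, the `(value[0], value[1], value[2])` triple unpacked from `reward_group` corresponds to a total reward plus its format and accuracy components. A hypothetical scorer with that return contract might look like the following; the answer-tag format and the equal weighting are invented for illustration and are not taken from coati's reward model.

```python
import re

def score_response(response: str, gt_answer: str):
    """Return (total_reward, format_reward, acc_reward) for one response.

    Hypothetical scorer: format is rewarded when the response wraps its
    final answer in <answer>...</answer> tags, accuracy when that answer
    matches the ground truth.
    """
    match = re.search(r"<answer>(.*?)</answer>", response, flags=re.DOTALL)
    format_reward = 1.0 if match else 0.0
    acc_reward = 1.0 if match and match.group(1).strip() == gt_answer.strip() else 0.0
    return format_reward + acc_reward, format_reward, acc_reward


print(score_response("Reasoning... <answer>42</answer>", "42"))  # (2.0, 1.0, 1.0)
```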