From 47d64937782b4e2b0d21b84eef07b5e258b82abb Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Tue, 11 Mar 2025 13:06:09 +0800
Subject: [PATCH] add response length

---
 .../coati/distributed/grpo_consumer.py | 28 +++++++++++++------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index ae9f2c400..55dfd09ab 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -59,6 +59,7 @@ class GRPOConsumer(BaseConsumer):
         self.accum_format_reward = torch.zeros(1, device=self.device)
         self.accum_acc_reward = torch.zeros(1, device=self.device)
         self.accum_advantages = torch.zeros(1, device=self.device)
+        self.accum_response_length = torch.zeros(1, device=self.device)
         self.accum_count = 0

         # Reference model is initialized from policy model.
@@ -83,7 +84,7 @@ class GRPOConsumer(BaseConsumer):
         self.policy_loss_fn = PolicyLoss()
         self.global_step = 0
         if use_wandb and self.rank == 0:
-            self.wandb_run = wandb.init(project="GRPO-Test", sync_tensorboard=True)
+            self.wandb_run = wandb.init(project="GRPO-V1", sync_tensorboard=True)

     def setup(self):
         super().setup()
@@ -109,6 +110,7 @@ class GRPOConsumer(BaseConsumer):
         action_mask = data["action_mask"]
         num_action = action_mask.shape[1]
         old_action_log_probs = data["action_log_probs"]
+        response_length = torch.sum(action_mask, dim=1).to(torch.float32)

         need_update = (step_idx + 1) % self.num_microbatches == 0

@@ -168,6 +170,7 @@ class GRPOConsumer(BaseConsumer):
             format_reward = all_reduce_mean(format_reward.mean(), self.plugin)
             acc_reward = all_reduce_mean(acc_reward.mean(), self.plugin)
             advantages = all_reduce_mean(advantages.mean(), self.plugin)
+            response_length = all_reduce_mean(response_length.mean(), self.plugin)
             # Calculate accumulate value.
             self.accum_loss.add_(loss.data)
             self.accum_reward.add_(reward.data)
@@ -175,6 +178,7 @@ class GRPOConsumer(BaseConsumer):
             self.accum_format_reward.add_(format_reward.data)
             self.accum_acc_reward.add_(acc_reward.data)
             self.accum_advantages.add_(advantages.data)
+            self.accum_response_length.add_(response_length.data)
             self.accum_count += 1
         if need_update:
             self.optimizer.step()
@@ -184,32 +188,38 @@ class GRPOConsumer(BaseConsumer):
                 print(
                     "Loss:",
                     self.accum_loss.item() / self.accum_count,
-                    "Reward:",
+                    "\nReward:",
                     self.accum_reward.item() / self.accum_count,
-                    "KL:",
-                    self.accum_kl.item() / self.accum_count,
-                    "Format Reward:",
+                    "\nFormat Reward:",
                     self.accum_format_reward.item() / self.accum_count,
-                    "Acc Reward:",
+                    "\nAcc Reward:",
                     self.accum_acc_reward.item() / self.accum_count,
-                    "Advantages:",
+                    "\nKL:",
+                    self.accum_kl.item() / self.accum_count,
+                    "\nAdvantages:",
                     self.accum_advantages.item() / self.accum_count,
+                    "\nResponse Length:",
+                    self.accum_response_length.item() / self.accum_count,
                 )
                 self.wandb_run.log(
                     {
                         "train/loss": self.accum_loss.item() / self.accum_count,
                         "train/reward": self.accum_reward.item() / self.accum_count,
-                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/format_reward": self.accum_format_reward.item() / self.accum_count,
                         "train/acc_reward": self.accum_acc_reward.item() / self.accum_count,
+                        "train/kl": self.accum_kl.item() / self.accum_count,
                         "train/advantages": self.accum_advantages.item() / self.accum_count,
+                        "train/response_length": self.accum_response_length.item() / self.accum_count,
                     }
                 )
             self.accum_loss.zero_()
             self.accum_reward.zero_()
-            self.accum_kl.zero_()
             self.accum_acc_reward.zero_()
             self.accum_format_reward.zero_()
+            self.accum_kl.zero_()
+            self.accum_advantages.zero_()
+            self.accum_response_length.zero_()
+
             self.accum_count = 0
             return loss_scalar

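
Note (not part of the patch): a minimal standalone sketch of what the new response-length metric measures, using a made-up toy action_mask where 1 marks a generated response token and 0 marks padding; the all_reduce_mean across ranks and the accum_* averaging over microbatches from the diff are omitted.

    import torch

    # Toy mask: 2 samples, up to 5 response tokens (1 = real token, 0 = padding).
    action_mask = torch.tensor([[1, 1, 1, 0, 0],
                                [1, 1, 1, 1, 1]])

    # Per-sample response length, as computed in the patch.
    response_length = torch.sum(action_mask, dim=1).to(torch.float32)  # tensor([3., 5.])

    # The patch then logs the batch mean, averaged across ranks and microbatches.
    print(response_length.mean().item())  # 4.0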