From c8b368c294a76fd06b45b5f683cde466e03f987c Mon Sep 17 00:00:00 2001
From: Tong Li
Date: Wed, 28 May 2025 19:18:09 +0800
Subject: [PATCH] add overlength sample count (#6332)

Co-authored-by: Tong Li
---
 .../coati/distributed/grpo_consumer.py       | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index eaf3521b6..1666ac582 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -85,6 +85,8 @@ class GRPOConsumer(BaseConsumer):
         self.effective_sample_count = 0
         self.effective_prompt_count = 0
         self.total_sample_count = 0
+        self.overlength_samples = 0
+        self.total_overlength_samples = 0
         self.project_name = project_name
         self.run_name = run_name
         self.wandb_group_name = wandb_group_name
@@ -227,10 +229,18 @@ class GRPOConsumer(BaseConsumer):
 
             # filter out overlength samples
             if self.filter_truncated_response and action_mask.size(1) == self.max_length:
+                old_loss_mask = loss_mask.clone()
                 loss_mask = torch.logical_and(
                     loss_mask,
                     action_mask[:, -1] == False,
                 )
+
+                self.overlength_samples = (old_loss_mask & ~loss_mask).sum().item()
+                self.overlength_samples = all_reduce_sum(
+                    torch.tensor(self.overlength_samples, device=loss_mask.device), self.plugin
+                )
+                self.total_overlength_samples += self.overlength_samples.item()
+
             prompt_level_mask = loss_mask.view(self.minibatch_size, self.num_generations)
 
             # [minibatch_size] -> calculate the number of effective prompts
@@ -484,9 +494,11 @@ class GRPOConsumer(BaseConsumer):
                 self.optimizer.zero_grad()
                 self.global_step += 1
                 sample_utilization = self.effective_sample_count / self.total_sample_count
+                overlength_samples_percentage = self.total_overlength_samples / self.total_sample_count
                 self.effective_prompt_count = 0
                 self.effective_sample_count = 0
                 self.total_sample_count = 0
+                self.total_overlength_samples = 0
                 loss_scalar = self.accum_loss.item()
                 if not self.plugin.pp_size > 1 or (
                     self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
@@ -502,6 +514,7 @@ class GRPOConsumer(BaseConsumer):
                     f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}",
                     f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}",
                     f"Sample_utilization: {sample_utilization:.4f}",
+                    f"Percentage of overlength samples: {overlength_samples_percentage:.4f}",
                 ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else [])
                 print("\n".join(to_log_msg))
                 metrics = {
@@ -513,6 +526,7 @@ class GRPOConsumer(BaseConsumer):
                     "train/advantages": self.accum_advantages.item() / self.accum_count,
                     "train/learning_rate": self.lr_scheduler.get_last_lr()[0],
                     "train/sample_utilization": sample_utilization,
+                    "train/percentage_overlength_samples": overlength_samples_percentage,
                     "rollout/temperature": data["temperature"].cpu().numpy()[0][0],
                 }
                 if self.policy_loss_fn.beta > 0:
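
For reference, a minimal standalone sketch of the counting logic this patch adds. The helper name count_overlength and the toy batch are illustrative only and not part of the patch; in the consumer the per-rank count is additionally summed across data-parallel ranks via all_reduce_sum(..., self.plugin) before being accumulated into total_overlength_samples.

# Standalone sketch of the overlength-sample counting added by this patch.
# count_overlength is a hypothetical helper for illustration; the patch runs
# this logic inline and additionally all-reduces the count across ranks.
import torch

def count_overlength(loss_mask: torch.Tensor, action_mask: torch.Tensor, max_length: int):
    """Drop samples truncated at max_length from the loss and count them."""
    if action_mask.size(1) == max_length:
        old_loss_mask = loss_mask.clone()
        # A sample whose last action position is still valid never emitted
        # EOS, i.e. it was cut off by the generation length limit.
        loss_mask = torch.logical_and(loss_mask, action_mask[:, -1] == False)  # noqa: E712
        return loss_mask, (old_loss_mask & ~loss_mask).sum().item()
    return loss_mask, 0

# Toy batch: 4 samples, responses padded/truncated to max_length=8.
# Samples 0 and 2 fill all 8 positions, i.e. they were truncated.
action_mask = torch.zeros(4, 8, dtype=torch.bool)
action_mask[0, :] = True
action_mask[1, :5] = True
action_mask[2, :] = True
action_mask[3, :3] = True
loss_mask = torch.ones(4, dtype=torch.bool)
loss_mask, overlength = count_overlength(loss_mask, action_mask, max_length=8)
print(loss_mask)   # tensor([False,  True, False,  True])
print(overlength)  # 2 -> accumulated into total_overlength_samples, then
                   # reported as total_overlength_samples / total_sample_count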