From d19f1f21b69a15c9fbb89597a85479baa12fa291 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Thu, 15 May 2025 18:30:32 +0800
Subject: [PATCH] move prompt-level-filtering to buffer side

---
 applications/ColossalChat/coati/distributed/grpo_consumer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index fcf7b0740..e709c8aed 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -254,7 +254,6 @@ class GRPOConsumer(BaseConsumer):
             total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
             total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
             self.effective_sample_count += effective_samples.item()
-            self.total_sample_count += total_samples.item()
             pbar.set_postfix(
                 {
                     "Global Step": self.global_step,
@@ -461,6 +460,9 @@ class GRPOConsumer(BaseConsumer):
             self.optimizer.step()
             self.optimizer.zero_grad()
             self.global_step += 1
+            self.total_sample_count = all_reduce_sum(
+                torch.tensor(self.total_sample_count).to(self.accum_loss.device), self.plugin
+            ).item()
             sample_utilization = self.effective_sample_count / self.total_sample_count
             self.effective_prompt_count = 0
             self.effective_sample_count = 0
@@ -564,6 +566,7 @@ class GRPOConsumer(BaseConsumer):
             "format_acc": torch.Tensor, [num_of_generation]
             "ans_acc": torch.Tensor, [num_of_generation]
         """
+        self.total_sample_count += rollout_group["input_ids"].size(0)
         if self.filter_range is not None:
             # filter prompt whoes accuracy is too high or too low (out of range)
             group_ans_acc = torch.mean(rollout_group["ans_acc"])
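
For illustration, a minimal self-contained sketch of the accounting pattern the patch adopts: the deleted line in the first hunk all-reduced total_samples on every micro-batch, while the new code counts samples locally as each rollout group enters the buffer and defers a single all-reduce to the optimizer-step boundary, where sample_utilization is computed. ConsumerSketch, on_rollout, and on_optimizer_step are hypothetical names, and all_reduce_sum here is a plain torch.distributed stand-in for the helper used in grpo_consumer.py; this assumes a process group is already initialized.

import torch
import torch.distributed as dist


def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for the all_reduce_sum helper in grpo_consumer.py; assumes
    # torch.distributed has been initialized (e.g. via init_process_group).
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


class ConsumerSketch:  # hypothetical class, not the actual GRPOConsumer
    def __init__(self) -> None:
        self.total_sample_count = 0      # accumulated locally on the buffer side
        self.effective_sample_count = 0  # samples that survive filtering

    def on_rollout(self, rollout_group: dict) -> None:
        # Buffer side: count every incoming sample before prompt-level
        # filtering can drop groups whose accuracy is out of range.
        self.total_sample_count += rollout_group["input_ids"].size(0)

    def on_optimizer_step(self, device: torch.device) -> float:
        # One all-reduce per optimizer step instead of one per micro-batch.
        total = all_reduce_sum(torch.tensor(self.total_sample_count, device=device)).item()
        sample_utilization = self.effective_sample_count / total
        self.total_sample_count = 0
        self.effective_sample_count = 0
        return sample_utilization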