From 8880b83791fa9d34d733f365236d8b7beabdb0ee Mon Sep 17 00:00:00 2001 From: Tong Li Date: Thu, 19 Jun 2025 14:02:08 +0800 Subject: [PATCH] add dp rank for multi-dp (#6351) Co-authored-by: Tong Li --- .../ColossalChat/coati/distributed/grpo_consumer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index d7d6221a1..8d50734a9 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -130,7 +130,10 @@ class GRPOConsumer(BaseConsumer): def setup(self): super().setup() if (not self.plugin.pp_size > 1 and self.rank == 0) or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 ): self.wandb_run = wandb.init( project=self.project_name, @@ -222,7 +225,6 @@ class GRPOConsumer(BaseConsumer): effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin) effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin) - total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin) self.effective_sample_count += effective_samples.item() pbar.set_postfix( { @@ -407,7 +409,10 @@ class GRPOConsumer(BaseConsumer): mean_kl.append(kl.data) mean_loss.append(loss.data) if not self.plugin.pp_size > 1 or ( - self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 + self.plugin.pp_size > 1 + and self.booster.plugin.stage_manager.is_last_stage() + and self.tp_rank == 0 + and self.dp_rank == 0 ): reward = all_reduce_mean(reward.mean(), self.plugin) format_acc = all_reduce_mean(format_acc.mean(), self.plugin)