add dp rank for multi-dp (#6351)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>
This commit is contained in:
Tong Li 2025-06-19 14:02:08 +08:00 committed by GitHub
parent dd49444dcb
commit 8880b83791
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -130,7 +130,10 @@ class GRPOConsumer(BaseConsumer):
def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
self.wandb_run = wandb.init(
project=self.project_name,
@ -222,7 +225,6 @@ class GRPOConsumer(BaseConsumer):
effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
self.effective_sample_count += effective_samples.item()
pbar.set_postfix(
{
@ -407,7 +409,10 @@ class GRPOConsumer(BaseConsumer):
mean_kl.append(kl.data)
mean_loss.append(loss.data)
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
and self.dp_rank == 0
):
reward = all_reduce_mean(reward.mean(), self.plugin)
format_acc = all_reduce_mean(format_acc.mean(), self.plugin)