mirror of https://github.com/hpcaitech/ColossalAI.git
synced 2025-07-28 14:00:34 +00:00
add dp rank for multi-dp (#6351)

Co-authored-by: Tong Li <tong.li35271158@gmail.com>

parent dd49444dcb
commit 8880b83791
@@ -130,7 +130,10 @@ class GRPOConsumer(BaseConsumer):
     def setup(self):
         super().setup()
         if (not self.plugin.pp_size > 1 and self.rank == 0) or (
-            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+            self.plugin.pp_size > 1
+            and self.booster.plugin.stage_manager.is_last_stage()
+            and self.tp_rank == 0
+            and self.dp_rank == 0
         ):
             self.wandb_run = wandb.init(
                 project=self.project_name,
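The guard's intent is to pick exactly one process to own the wandb run once data parallelism replicates the last pipeline stage; before this change, every dp replica with tp_rank == 0 on that stage would call wandb.init. A minimal standalone sketch of the predicate (the should_log name is hypothetical; pp_size, rank, tp_rank, and dp_rank mirror the attributes in the diff):

def should_log(pp_size: int, rank: int, is_last_stage: bool,
               tp_rank: int, dp_rank: int) -> bool:
    """Select exactly one process for logging.

    Without pipeline parallelism, global rank 0 logs. With pipeline
    parallelism, the final loss lives on the last stage, so the single
    process there with tp_rank == 0 and dp_rank == 0 logs.
    """
    if pp_size <= 1:
        return rank == 0
    return is_last_stage and tp_rank == 0 and dp_rank == 0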
@@ -222,7 +225,6 @@ class GRPOConsumer(BaseConsumer):
         effective_samples = all_reduce_sum(torch.sum(loss_mask), self.plugin)
         effective_tokens_count = torch.sum(action_mask, dim=-1) * loss_mask
         total_effective_tokens_count = all_reduce_sum(torch.sum(effective_tokens_count), self.plugin)
-        total_samples = all_reduce_sum(torch.sum(torch.ones_like(loss_mask, device=loss_mask.device)), self.plugin)
         self.effective_sample_count += effective_samples.item()
         pbar.set_postfix(
             {
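The removed line computed total_samples with an extra all-reduce over a tensor of ones; the diff introduces no replacement, so the value was presumably unused and dropping it saves one collective per step. For reference, a hedged sketch of what the surviving lines compute, assuming loss_mask holds one 0/1 entry per sample and action_mask marks generated tokens per position; the file's all_reduce_sum(..., plugin) helper is stood in for by a direct torch.distributed call:

import torch
import torch.distributed as dist

def count_effective(loss_mask: torch.Tensor, action_mask: torch.Tensor):
    # Number of samples that contribute to the loss on this rank.
    effective_samples = torch.sum(loss_mask)
    # Generated tokens per sample, zeroed for masked-out samples.
    effective_tokens = torch.sum(torch.sum(action_mask, dim=-1) * loss_mask)
    # Sum across ranks (stand-in for the file's all_reduce_sum helper).
    if dist.is_initialized():
        dist.all_reduce(effective_samples, op=dist.ReduceOp.SUM)
        dist.all_reduce(effective_tokens, op=dist.ReduceOp.SUM)
    return effective_samples.item(), effective_tokens.item()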
@@ -407,7 +409,10 @@ class GRPOConsumer(BaseConsumer):
             mean_kl.append(kl.data)
             mean_loss.append(loss.data)
         if not self.plugin.pp_size > 1 or (
-            self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
+            self.plugin.pp_size > 1
+            and self.booster.plugin.stage_manager.is_last_stage()
+            and self.tp_rank == 0
+            and self.dp_rank == 0
         ):
             reward = all_reduce_mean(reward.mean(), self.plugin)
             format_acc = all_reduce_mean(format_acc.mean(), self.plugin)
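Note the asymmetry with the setup guard: the non-pipeline branch here (not self.plugin.pp_size > 1) is true on every rank, not just rank 0, because all_reduce_mean is a collective that each participating rank must enter; the new dp_rank == 0 term only narrows the pipeline branch. A quick check of the should_log sketch above against an illustrative 2-stage pipeline with 2 dp replicas (rank numbers are hypothetical):

# Last pipeline stage, tp_rank 0: only the dp_rank 0 replica logs.
assert should_log(pp_size=2, rank=4, is_last_stage=True, tp_rank=0, dp_rank=0)
assert not should_log(pp_size=2, rank=6, is_last_stage=True, tp_rank=0, dp_rank=1)
# No pipeline parallelism: only global rank 0 logs.
assert should_log(pp_size=1, rank=0, is_last_stage=False, tp_rank=0, dp_rank=0)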