fix reward tagging bug

This commit is contained in:
YeAnbang
2025-05-03 14:34:04 +08:00
parent da867a4d8f
commit 2999bd4cc8
4 changed files with 29 additions and 32 deletions

View File

@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):
def setup(self):
super().setup()
if self.use_wandb and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
if (
self.use_wandb
and self.dp_rank == 0
and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
)
):
# Initialize wandb.
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
if not self.plugin.pp_size > 1 or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
):
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
if self.dp_rank == 0 and (
(not self.plugin.pp_size > 1 and self.rank == 0)
or (
self.plugin.pp_size > 1
and self.booster.plugin.stage_manager.is_last_stage()
and self.tp_rank == 0
)
):
to_log_msg = [
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",