mirror of
https://github.com/hpcaitech/ColossalAI.git
synced 2025-09-03 10:06:44 +00:00
fix reward taging bug
This commit is contained in:
@@ -149,9 +149,13 @@ class GRPOConsumer(BaseConsumer):
|
||||
|
||||
def setup(self):
|
||||
super().setup()
|
||||
if self.use_wandb and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
if (
|
||||
self.use_wandb
|
||||
and self.dp_rank == 0
|
||||
and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0)
|
||||
)
|
||||
):
|
||||
# Initialize wandb.
|
||||
name = f"{self.generate_config['backend']}_bs_{self.batch_size*self.dp_size}_temp_{self.generate_config['temperature']:.01f}_top_p_{self.generate_config['top_p']:.02f}"
|
||||
@@ -482,8 +486,13 @@ class GRPOConsumer(BaseConsumer):
|
||||
if not self.plugin.pp_size > 1 or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
):
|
||||
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
|
||||
self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0
|
||||
if self.dp_rank == 0 and (
|
||||
(not self.plugin.pp_size > 1 and self.rank == 0)
|
||||
or (
|
||||
self.plugin.pp_size > 1
|
||||
and self.booster.plugin.stage_manager.is_last_stage()
|
||||
and self.tp_rank == 0
|
||||
)
|
||||
):
|
||||
to_log_msg = [
|
||||
f"Loss: {self.accum_loss.item() / self.accum_count:.4f}",
|
||||
|
Reference in New Issue
Block a user