diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py index 831612f80..d4589b0e2 100644 --- a/applications/ColossalChat/coati/distributed/grpo_consumer.py +++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py @@ -474,19 +474,14 @@ class GRPOConsumer(BaseConsumer): if (not self.plugin.pp_size > 1 and self.rank == 0) or ( self.plugin.pp_size > 1 and self.booster.plugin.stage_manager.is_last_stage() and self.tp_rank == 0 ): - to_log_msg = ( - [ - f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", - f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", - f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", - f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", - f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", - f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", - ] - + [f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] - if self.policy_loss_fn.beta > 0 - else [] - ) + to_log_msg = [ + f"Loss: {self.accum_loss.item() / self.accum_count:.4f}", + f"Reward: {self.accum_reward.item() / self.accum_count:.4f}", + f"format Reward: {self.accum_format_acc.item() / self.accum_count:.4f}", + f"Acc Reward: {self.accum_ans_acc.item() / self.accum_count:.4f}", + f"Advantages: {self.accum_advantages.item() / self.accum_count:.4f}", + f"Response Length: {self.accum_response_length.item() / self.accum_count:.4f}", + ] + ([f"KL: {self.accum_kl.item() / self.accum_count:.4f}"] if self.policy_loss_fn.beta > 0 else []) print("\n".join(to_log_msg)) metrics = { "metrics/reward": self.accum_reward.item() / self.accum_count,